Commit d3d31a58 authored by zhoukunsheng

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into all_any

This diff is collapsed.
@@ -51,9 +51,7 @@ else()
   cc_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
 endif()
-cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_base scope lod_tensor)
 cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
-cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
 if(WITH_GPU)
   cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper gpu_info)
@@ -74,7 +72,7 @@ cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS grap
 cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass)
 cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
-    scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
+    scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
 cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
......
@@ -11,9 +11,8 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include <algorithm>
 #include "paddle/fluid/framework/details/all_reduce_op_handle.h"
+#include <algorithm>
 #include "paddle/fluid/framework/details/container_cast.h"
 #include "paddle/fluid/framework/details/reduce_and_gather.h"
 #include "paddle/fluid/framework/details/variable_visitor.h"
@@ -56,6 +55,7 @@ void AllReduceOpHandle::RunImpl() {
   platform::RecordEvent record_event(Name());
   WaitInputVarGenerated();
   auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
   auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
   PADDLE_ENFORCE_EQ(
......
@@ -57,7 +57,7 @@ struct BroadcastOpHandle : public OpHandleBase {
   std::string Name() const override;
-  bool IsMultiDeviceTransfer() override { return false; };
+  bool IsMultiDeviceTransfer() override { return true; };
  protected:
   void RunImpl() override;
......
@@ -147,6 +147,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
     // Verify that the graph is correct for multi-device executor.
     AppendPass("multi_devices_check_pass");
+    if (VLOG_IS_ON(2)) {
+      AppendPass("all_reduce_deps_pass");
+    }
     if (SeqOnlyAllReduceOps(strategy)) {
       VLOG(10) << "Add all_reduce_deps_pass";
       AppendPass("all_reduce_deps_pass");
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/data_balance_op_handle.h"
#include <algorithm>
#include "paddle/fluid/framework/details/container_cast.h"
namespace paddle {
namespace framework {
namespace details {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
DataBalanceOpHandle::DataBalanceOpHandle(
ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs)
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {
if (ctxs) {
for (auto &p : places_) {
this->SetDeviceContext(p, ctxs->DevCtx(p));
}
}
}
#else
DataBalanceOpHandle::DataBalanceOpHandle(
ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places)
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
#endif
std::string DataBalanceOpHandle::Name() const { return "data balance"; }
std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan(
const std::vector<int> &device_sizes) {
int device_num = device_sizes.size();
int total_size = 0;
int empty_num = 0;
std::vector<std::array<int, 2>> size_device_vec;
size_device_vec.reserve(device_num);
for (int i = 0; i < device_num; ++i) {
if (device_sizes[i] == 0) {
++empty_num;
}
total_size += device_sizes[i];
size_device_vec.push_back({{device_sizes[i], i}});
}
std::vector<std::array<int, 3>> res;
if (empty_num == 0) {
// No need to do data balance.
return res;
}
if (total_size < device_num) {
// Not enough data.
PADDLE_THROW_EOF();
}
std::sort(size_device_vec.begin(), size_device_vec.end(),
[](const std::array<int, 2> &a, const std::array<int, 2> &b) {
return a[0] > b[0];
});
int expected_device_size = total_size / device_num;
int src_idx = 0;
for (int dst_idx = device_num - empty_num; dst_idx < device_num; ++dst_idx) {
if (size_device_vec[src_idx][0] <= expected_device_size) {
++src_idx;
PADDLE_ENFORCE_LT(
src_idx, device_num - empty_num,
"In current srategy an empty tensor should not be copy source.");
}
size_device_vec[src_idx][0] -= expected_device_size;
size_device_vec[dst_idx][0] += expected_device_size;
res.push_back({{size_device_vec[src_idx][1], size_device_vec[dst_idx][1],
expected_device_size}});
}
return res;
}
void DataBalanceOpHandle::RunImpl() {
PADDLE_ENFORCE_GT(places_.size(), 1UL,
"Data balance can only be enabled when the number of "
"places to run larger than 1.");
auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
PADDLE_ENFORCE(in_var_handles.size() % places_.size() == 0);
PADDLE_ENFORCE_EQ(
in_var_handles.size(), out_var_handles.size(),
"The NoDummyInputSize and NoDummyOutputSize should be equal.");
int data_num = in_var_handles.size() / places_.size();
WaitInputVarGenerated();
std::vector<std::vector<LoDTensor *>> lod_tensors(data_num);
std::vector<int> device_sizes;
for (int i = 0; i < static_cast<int>(in_var_handles.size()); ++i) {
PADDLE_ENFORCE_EQ(in_var_handles[i]->name(), out_var_handles[i]->name(),
"The name of input and output should be equal.");
int place_idx = i / data_num;
int data_idx = i % data_num;
auto *local_scope =
local_scopes_[place_idx]->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto *tensor_var = local_scope->FindVar(in_var_handles[i]->name());
PADDLE_ENFORCE(tensor_var->IsType<LoDTensor>());
auto *tensor = tensor_var->GetMutable<LoDTensor>();
lod_tensors[data_idx].push_back(tensor);
int ins_size =
tensor->lod().empty() ? tensor->dims()[0] : tensor->NumElements();
if (data_idx == 0) {
device_sizes.emplace_back(ins_size);
} else {
PADDLE_ENFORCE_EQ(
ins_size, device_sizes.at(place_idx),
"All data on the same device shall have the same batch size.");
}
}
const auto &balance_plan = GetBalancePlan(device_sizes);
for (const auto &trans : balance_plan) {
for (int data_idx = 0; data_idx < data_num; ++data_idx) {
LoDTensor *src_tensor = lod_tensors[data_idx][trans[0]];
LoDTensor *dst_tensor = lod_tensors[data_idx][trans[1]];
int trans_ins_size = trans[2];
LoD src_lod = src_tensor->lod();
int src_ins_size =
src_lod.empty() ? src_tensor->dims()[0] : src_tensor->NumElements();
int cut_point = src_ins_size - trans_ins_size;
if (!src_lod.empty()) {
for (auto &level : src_lod) {
cut_point = level[cut_point];
}
}
TensorCopySync(src_tensor->Slice(cut_point, src_tensor->dims()[0]),
dst_tensor->place(), dst_tensor);
src_tensor->ShareDataWith(src_tensor->Slice(0, cut_point));
if (!src_lod.empty()) {
dst_tensor->set_lod(SliceInLevel(
src_lod, 0, src_ins_size - trans_ins_size, src_ins_size));
src_tensor->set_lod(
SliceInLevel(src_lod, 0, 0, src_ins_size - trans_ins_size));
}
}
}
}
} // namespace details
} // namespace framework
} // namespace paddle
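
The balancing arithmetic above is easier to follow in isolation. Below is a minimal, hedged sketch of the same plan computation (hypothetical device sizes, invented function name BalancePlan, error handling reduced to an early return; it is not part of the commit):

#include <algorithm>
#include <array>
#include <iostream>
#include <vector>

// Simplified restatement of GetBalancePlan(): move total/num_devices
// instances from the fullest remaining source to each empty device.
std::vector<std::array<int, 3>> BalancePlan(std::vector<int> sizes) {
  std::vector<std::array<int, 3>> plan;
  int num = static_cast<int>(sizes.size()), total = 0, empty = 0;
  std::vector<std::array<int, 2>> size_dev;
  for (int i = 0; i < num; ++i) {
    total += sizes[i];
    if (sizes[i] == 0) ++empty;
    size_dev.push_back({{sizes[i], i}});
  }
  // Nothing to do, or not enough data to share (the real handle throws here).
  if (empty == 0 || total < num) return plan;
  std::sort(size_dev.begin(), size_dev.end(),
            [](const std::array<int, 2> &a, const std::array<int, 2> &b) {
              return a[0] > b[0];
            });
  int expected = total / num, src = 0;
  for (int dst = num - empty; dst < num; ++dst) {
    if (size_dev[src][0] <= expected) ++src;
    size_dev[src][0] -= expected;
    plan.push_back({{size_dev[src][1], size_dev[dst][1], expected}});
  }
  return plan;
}

int main() {
  // Device batch sizes {5, 3, 0, 0} -> expected 2 per device; the plan becomes
  // {0, 2, 2} and {0, 3, 2}: device 0 donates two instances to devices 2 and 3.
  for (const auto &t : BalancePlan({5, 3, 0, 0}))
    std::cout << t[0] << " -> " << t[1] << " : " << t[2] << "\n";
}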
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace framework {
namespace details {
struct DataBalanceOpHandle : public OpHandleBase {
public:
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
DataBalanceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs);
#else
DataBalanceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places);
#endif
std::string Name() const override;
bool IsMultiDeviceTransfer() override { return false; };
protected:
void RunImpl() override;
private:
// std::vector<(src_dev_id, dst_dev_id, trans_size)>
std::vector<std::array<int, 3>> GetBalancePlan(
const std::vector<int> &batch_size_per_device);
const std::vector<Scope *> local_scopes_;
const std::vector<platform::Place> places_;
};
} // namespace details
} // namespace framework
} // namespace paddle
@@ -82,6 +82,8 @@ void FetchOpHandle::WaitInputVarGenerated(const platform::Place &place) {
   }
 }
+bool FetchOpHandle::IsMultiDeviceTransfer() { return true; }
 std::string FetchOpHandle::Name() const { return "Fetch"; }
 }  // namespace details
......
@@ -39,6 +39,8 @@ struct FetchOpHandle : public OpHandleBase {
   std::string Name() const override;
+  bool IsMultiDeviceTransfer() override;
  protected:
   void RunImpl() override;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fuse_vars_op_handle.h"
namespace paddle {
namespace framework {
namespace details {
void FuseVarsOpHandle::RunImpl() {
WaitInputVarGenerated(place_);
auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
PADDLE_ENFORCE_EQ(in_var_handles.size(), 0UL);
PADDLE_ENFORCE_EQ(out_var_handles.size() - 1, inputs_numel_.size(), "");
auto scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto out_var_handle = out_var_handles[0];
auto out_var = scope->Var(out_var_handle->name());
auto out_tensor = out_var->GetMutable<LoDTensor>();
out_tensor->Resize({total_numel_}).mutable_data(this->place_, type_);
int64_t s = 0;
for (size_t i = 1; i < out_var_handles.size(); ++i) {
auto out_name = out_var_handles[i]->name();
auto out_t = scope->Var(out_name)->GetMutable<LoDTensor>();
auto numel = this->inputs_numel_.at(out_name);
out_t->ShareDataWith(out_tensor->Slice(s, s + numel));
s += numel;
}
this->RunAndRecordEvent([] {});
}
std::string FuseVarsOpHandle::Name() const { return "fuse vars"; }
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
namespace details {
struct FuseVarsOpHandle : public OpHandleBase {
public:
FuseVarsOpHandle(ir::Node *node, Scope *local_scope,
const platform::Place &place,
const std::unordered_map<std::string, int64_t> &inputs_numel,
const proto::VarType::Type var_type)
: OpHandleBase(node),
local_scope_(local_scope),
place_(place),
inputs_numel_(inputs_numel),
type_(var_type) {
total_numel_ = 0;
for (auto in_numel : inputs_numel) {
PADDLE_ENFORCE_GT(in_numel.second, 0);
total_numel_ += in_numel.second;
}
}
std::string Name() const override;
bool IsMultiDeviceTransfer() override { return false; };
protected:
void RunImpl() override;
private:
Scope *local_scope_;
const platform::Place place_;
const std::unordered_map<std::string, int64_t> inputs_numel_;
const proto::VarType::Type type_;
int64_t total_numel_;
};
} // namespace details
} // namespace framework
} // namespace paddle
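
The core idea of the removed FuseVarsOpHandle — one contiguous allocation carved into per-variable views — can be illustrated without the framework. A small sketch under assumed names (plain offsets instead of LoDTensor::Slice/ShareDataWith, std::map instead of the unordered map, all values invented):

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // numel of each variable to be fused, mirroring inputs_numel_ above.
  std::map<std::string, int64_t> inputs_numel = {{"w0", 4}, {"w1", 6}, {"w2", 2}};
  int64_t total = 0;
  for (const auto &kv : inputs_numel) total += kv.second;

  // One backing buffer of `total` elements plays the role of the fused tensor.
  std::vector<float> fused(total, 0.f);

  // Each variable becomes an [offset, offset + numel) view into that buffer,
  // the same way RunImpl() called ShareDataWith(out_tensor->Slice(s, s + numel)).
  int64_t s = 0;
  for (const auto &kv : inputs_numel) {
    std::cout << kv.first << " -> [" << s << ", " << s + kv.second << ")\n";
    s += kv.second;
  }
}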
@@ -112,19 +112,20 @@ void FusedAllReduceOpHandle::RunImpl() {
   });
   for (size_t k = 1; k < g_tensor.size(); ++k) {
-    const void *pre_address = g_tensor.at(k - 1).second->data<void>();
+    const void *cur_address = g_tensor.at(k - 1).second->data<void>();
     int64_t len = g_tensor.at(k - 1).second->numel();
     auto offset = len * framework::SizeOfType(dtype);
-    void *next_address = reinterpret_cast<void *>(
-        reinterpret_cast<uintptr_t>(pre_address) + offset);
-    const void *cur_address = g_tensor.at(k).second->data<void>();
-    VLOG(10) << k << ", "
-             << " pre_address(" << g_tensor.at(k - 1).first
-             << "): " << pre_address << ", cur_address("
-             << g_tensor.at(k).first << "): " << cur_address
-             << ", offset:" << offset << ", " << next_address << ", "
-             << cur_address;
-    PADDLE_ENFORCE_EQ(next_address, cur_address);
+    void *infer_next_address = reinterpret_cast<void *>(
+        reinterpret_cast<uintptr_t>(cur_address) + offset);
+    const void *next_address = g_tensor.at(k).second->data<void>();
+    VLOG(10) << string::Sprintf(
+        "Input[%d](%s) address: 0X%02x, Input[%d](%s) address: 0X%02x, Infer "
+        "input[%d] address: 0X%02x. The offset: %d",
+        k - 1, g_tensor.at(k - 1).first, cur_address, g_tensor.at(k).first, k,
+        next_address, k, infer_next_address, offset);
+    PADDLE_ENFORCE_EQ(infer_next_address, next_address,
+                      "The address is not consistent.");
   }
 }
......
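
The renamed variables state the invariant more directly: the inferred end of gradient k-1 must equal the start of gradient k. A standalone sketch of the same pointer arithmetic, assuming a float dtype and an invented buffer layout (not taken from the commit):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // One contiguous buffer holding two "gradients" of 8 and 16 floats.
  std::vector<float> buffer(24);
  const void *grad0 = buffer.data();
  const void *grad1 = buffer.data() + 8;

  // Same check as above: infer where the next tensor should begin from the
  // current address plus numel * sizeof(dtype), then compare with the real one.
  int64_t len = 8;
  auto offset = len * sizeof(float);
  const void *infer_next_address = reinterpret_cast<const void *>(
      reinterpret_cast<uintptr_t>(grad0) + offset);
  assert(infer_next_address == grad1 && "The address is not consistent.");
  return 0;
}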
@@ -14,13 +14,15 @@
 #include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
 #include <algorithm>
 #include <fstream>
+#include <memory>
 #include <string>
+#include <unordered_map>
+#include <unordered_set>
 #include <utility>
 #include <vector>
 #include "paddle/fluid/framework/details/all_reduce_op_handle.h"
 #include "paddle/fluid/framework/details/broadcast_op_handle.h"
 #include "paddle/fluid/framework/details/computation_op_handle.h"
-#include "paddle/fluid/framework/details/data_balance_op_handle.h"
 #include "paddle/fluid/framework/details/fused_broadcast_op_handle.h"
 #include "paddle/fluid/framework/details/reduce_op_handle.h"
 #include "paddle/fluid/framework/details/rpc_op_handle.h"
......
@@ -13,6 +13,7 @@
 // limitations under the License.
 #include "paddle/fluid/framework/details/op_handle_base.h"
 #include <map>
+#include <unordered_set>
 namespace paddle {
 namespace framework {
@@ -41,15 +42,42 @@ OpHandleBase::~OpHandleBase() {
 void OpHandleBase::Run(bool use_cuda) {
 #ifdef PADDLE_WITH_CUDA
-  if (events_.empty() && use_cuda) {
+  if (events_.empty() && use_cuda && dev_ctxes_.size() > 0) {
     for (auto &p : dev_ctxes_) {
       int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
       PADDLE_ENFORCE(cudaSetDevice(dev_id));
       PADDLE_ENFORCE(
           cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming));
     }
+    if (IsMultiDeviceTransfer() && dev_ctxes_.size() > 0) {
+      for (auto &out_var : outputs_) {
+        auto *out_var_handle = dynamic_cast<VarHandle *>(out_var);
+        if (out_var_handle) {
+          int dev_id =
+              boost::get<platform::CUDAPlace>(out_var_handle->place()).device;
+          out_var_handle->SetGenerateEvent(events_[dev_id]);
+        }
+      }
+    } else {
+      PADDLE_ENFORCE_EQ(dev_ctxes_.size(), 1UL,
+                        "%s should have only one dev_ctx.", Name());
+      auto &place = dev_ctxes_.begin()->first;
+      int dev_id = boost::get<platform::CUDAPlace>(place).device;
+      for (auto &out_var : outputs_) {
+        auto *out_var_handle = dynamic_cast<VarHandle *>(out_var);
+        if (out_var_handle) {
+          PADDLE_ENFORCE(
+              platform::is_same_place(place, out_var_handle->place()),
+              "The place of input(%s) is not consistent with the "
+              "place of current op(%s).",
+              out_var_handle->Name(), Name());
+          out_var_handle->SetGenerateEvent(events_[dev_id]);
+        }
+      }
+    }
   }
 #else
   PADDLE_ENFORCE(!use_cuda);
 #endif
@@ -93,17 +121,48 @@ void OpHandleBase::AddOutput(VarHandleBase *out) {
 void OpHandleBase::WaitInputVarGenerated() {
   for (auto in_var : inputs_) {
     if (NeedWait(in_var)) {
-      for (auto &pair : dev_ctxes_) {
-        in_var->GeneratedOp()->RecordWaitEventOnCtx(pair.second);
+      // Dummy Variable is used to represent dependencies between operators, so
+      // there doesn't add event for it.
+      auto *in_var_handle = dynamic_cast<VarHandle *>(in_var);
+      if (in_var_handle) {
+        auto &place = in_var_handle->place();
+        if (platform::is_gpu_place(place)) {
+#ifdef PADDLE_WITH_CUDA
+          auto stream =
+              static_cast<platform::CUDADeviceContext *>(dev_ctxes_.at(place))
+                  ->stream();
+          PADDLE_ENFORCE(
+              cudaStreamWaitEvent(stream, in_var_handle->GetEvent(), 0));
+#else
+          PADDLE_THROW("Doesn't compile the GPU.");
+#endif
+        }
+        // There are nothing to do when the place is CPUPlace.
       }
     }
   }
 }
 void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
-  for (auto *in : inputs_) {
-    if (NeedWait(in)) {
-      in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_.at(place));
+  for (auto in_var : inputs_) {
+    if (NeedWait(in_var)) {
+      // Dummy Variable is used to represent dependencies between operators, so
+      // there doesn't add event for it.
+      auto *in_var_handle = dynamic_cast<VarHandle *>(in_var);
+      if (in_var_handle) {
+        if (platform::is_gpu_place(in_var_handle->place())) {
+#ifdef PADDLE_WITH_CUDA
+          auto stream = static_cast<platform::CUDADeviceContext *>(
+                            dev_ctxes_.at(in_var_handle->place()))
+                            ->stream();
+          PADDLE_ENFORCE(
+              cudaStreamWaitEvent(stream, in_var_handle->GetEvent(), 0));
+#else
+          PADDLE_THROW("Doesn't compile the GPU.");
+#endif
+        }
+        // There are nothing to do when the place is CPUPlace.
+      }
     }
   }
 }
......
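
The new WaitInputVarGenerated() relies on the standard CUDA event pattern: the producing op records an event on its stream once the variable is ready, and consumers make their own stream wait on that event instead of blocking the host. A hedged, framework-free sketch of that pattern (raw CUDA runtime calls, error checks omitted; compile with nvcc; not part of the commit):

#include <cuda_runtime.h>

int main() {
  cudaStream_t producer, consumer;
  cudaStreamCreate(&producer);
  cudaStreamCreate(&consumer);

  // 1. Create a timing-free event once, as Run() does per device.
  cudaEvent_t generated;
  cudaEventCreateWithFlags(&generated, cudaEventDisableTiming);

  // ... producer enqueues the kernels that generate the variable ...

  // 2. The producing side records the event on its stream when done.
  cudaEventRecord(generated, producer);

  // 3. The consuming stream waits on the event instead of synchronizing the
  //    whole device, which is the behaviour WaitInputVarGenerated() now uses.
  cudaStreamWaitEvent(consumer, generated, 0);

  // ... consumer enqueues work that reads the variable ...

  cudaStreamSynchronize(consumer);
  cudaEventDestroy(generated);
  cudaStreamDestroy(producer);
  cudaStreamDestroy(consumer);
  return 0;
}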
@@ -14,7 +14,6 @@
 #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
-#include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/platform/profiler.h"
@@ -27,62 +26,49 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
     : graph_(graph),
       pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
                                        : nullptr),
+      prepare_pool_(1),
       local_scopes_(local_scopes),
       places_(places),
       fetch_ctxs_(places),
-      running_ops_(0),
-      strategy_(strategy) {}
+      strategy_(strategy) {
+  PrepareOpDeps();
+  CopyOpDeps();
+}
 FeedFetchList ThreadedSSAGraphExecutor::Run(
     const std::vector<std::string> &fetch_tensors) {
   std::unique_ptr<platform::RecordEvent> event(
       new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare"));
-  std::unordered_map<OpHandleBase *, size_t> pending_ops;
-  std::unordered_set<VarHandleBase *> pending_vars;
-  auto ready_vars = std::make_shared<BlockingQueue<VarHandleBase *>>();
-  std::unordered_set<OpHandleBase *> ready_ops;
+  std::unique_ptr<OpDependentData> op_deps = op_deps_futures_.get();
+  CopyOpDeps();
+  VLOG(10) << "ThreadedSSAGraphExecutor::Run";
+  std::shared_ptr<BlockingQueue<VarHandleBase *>> ready_vars(
+      new BlockingQueue<VarHandleBase *>);
+  auto &pending_ops = op_deps->pending_ops_;
+  auto &pending_vars = op_deps->pending_vars_;
+  auto &ready_ops = op_deps->ready_ops_;
   // For ops (e.g. nccl_all_reduce) that need to coordinate multiple
   // streams from multiple GPUs, it's faster to buffer them and schedule
   // together since we currently cannot overlap computation and memcpy streams.
   // Should revisit it if overlapping is available.
   std::unordered_set<OpHandleBase *> delayed_ops;
-  // Transform SSAGraph to pending_ops & pending_vars
-  for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
-    for (auto &name_pair : var_map) {
-      for (auto &version_pair : name_pair.second) {
-        InsertPendingVar(&pending_vars, ready_vars.get(), version_pair);
-      }
-    }
-  }
-  for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) {
-    InsertPendingVar(&pending_vars, ready_vars.get(), var);
-  }
-  for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
-    if (op->Inputs().empty()) {  // Special case, Op has no input.
-      ready_ops.insert(op);
-    } else {
-      InsertPendingOp(&pending_ops, op);
-    }
-  }
   // Step 2. Insert FetchOps
   std::vector<FetchOpHandle *> fetch_ops;
   std::unordered_set<VarHandleBase *> fetch_dependencies;
   FeedFetchList fetch_data(fetch_tensors.size());
-  InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops,
-                 &pending_vars, ready_vars.get(), &fetch_data);
+  InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &ready_ops,
+                 &pending_ops, &pending_vars, &fetch_data);
   auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
     for (auto *op : set) {
-      running_ops_++;
       RunOp(ready_vars, op);
     }
     set.clear();
   };
+  auto run_all_op = [&](OpHandleBase *op) { RunOp(ready_vars, op); };
   // Clean run context
   run_op_futures_.clear();
   exception_holder_.Clear();
@@ -91,19 +77,11 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
   while (!pending_vars.empty()) {
     // 1. Run All Ready ops
     // Keep loop until all vars are ready.
-    //
-    // NOTE: DelayedOps have a lower priority. It will be scheduled after all
-    // ready_ops have been performed.
-    if (ready_ops.empty() && strategy_.allow_op_delay_ && running_ops_ == 0) {
-      run_all_ops(delayed_ops);
-    } else {
-      run_all_ops(ready_ops);
-    }
+    run_all_ops(ready_ops);
     // 2. Find ready variable
     bool timeout;
     auto cur_ready_vars = ready_vars->PopAll(1, &timeout);
     if (timeout) {
       if (exception_holder_.IsCaught()) {
         for (auto &run_op_future : run_op_futures_) {
@@ -115,6 +93,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
         continue;
       }
     }
     // 3. Remove the dependency of ready_var.
     // Find the ready_ops after the ready_var.
     for (auto ready_var : cur_ready_vars) {
@@ -123,11 +102,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
         auto &deps = pending_ops[op];
         --deps;
         if (deps == 0) {
-          if (op->IsMultiDeviceTransfer() && strategy_.allow_op_delay_) {
-            delayed_ops.insert(op);
-          } else {
-            ready_ops.insert(op);
-          }
+          run_all_op(op);
         }
       }
     }
@@ -143,16 +118,17 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
     const std::vector<std::string> &fetch_tensors,
     std::vector<FetchOpHandle *> *fetch_ops,
     std::unordered_set<VarHandleBase *> *fetch_dependencies,
+    std::unordered_set<OpHandleBase *> *ready_ops,
     std::unordered_map<OpHandleBase *, size_t> *pending_ops,
     std::unordered_set<VarHandleBase *> *pending_vars,
-    BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data) {
+    FeedFetchList *fetch_data) {
   std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
+  std::unordered_set<VarHandleBase *> local_ready_vars;
   for (auto &fetch_var_name : fetch_tensors) {
     for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
       auto it = var_map.find(fetch_var_name);
       if (it != var_map.end()) {
-        fetched_vars[fetch_var_name].push_back(*it->second.rbegin());
+        fetched_vars[fetch_var_name].emplace_back(*it->second.rbegin());
       }
     }
   }
@@ -161,8 +137,9 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
     auto &var_name = fetch_tensors[i];
     auto fetched_var_it = fetched_vars.find(var_name);
     PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(),
-                   "Cannot find fetched variable.(Perhaps the main_program "
-                   "is not set to ParallelExecutor)");
+                   "Cannot find fetched variable(%s).(Perhaps the main_program "
+                   "is not set to ParallelExecutor)",
+                   var_name);
     auto &vars = fetched_var_it->second;
@@ -184,9 +161,23 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
     auto *fetch_dummy = new DummyVarHandle(fetch_var);
     op->AddOutput(fetch_dummy);
     fetch_dependencies->emplace(fetch_dummy);
-    this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy);
-    this->InsertPendingOp(pending_ops, op);
+    this->InsertPendingVar(pending_vars, &local_ready_vars, fetch_dummy);
+    size_t wait_input_num = 0;
+    std::unordered_set<VarHandleBase *> input_set(vars.begin(), vars.end());
+    for (auto *var : input_set) {
+      if (pending_vars->count(var)) {
+        ++wait_input_num;
+      }
+    }
+    if (wait_input_num) {
+      pending_ops->insert({op, wait_input_num});
+    } else {
+      ready_ops->insert(static_cast<OpHandleBase *>(op));
+    }
   }
+  PADDLE_ENFORCE_EQ(local_ready_vars.size(), 0);
 }
 void ThreadedSSAGraphExecutor::InsertPendingOp(
@@ -197,11 +188,63 @@ void ThreadedSSAGraphExecutor::InsertPendingOp(
 void ThreadedSSAGraphExecutor::InsertPendingVar(
     std::unordered_set<VarHandleBase *> *pending_vars,
-    BlockingQueue<VarHandleBase *> *ready_vars, VarHandleBase *var) const {
+    std::unordered_set<VarHandleBase *> *ready_vars, VarHandleBase *var) const {
   pending_vars->insert(var);
   if (var->GeneratedOp() == nullptr) {
-    ready_vars->Push(var);
+    ready_vars->insert(var);
   }
 }
+void ThreadedSSAGraphExecutor::PrepareOpDeps() {
+  op_deps_.reset(new OpDependentData());
+  std::unordered_map<OpHandleBase *, size_t> &pending_ops =
+      op_deps_->pending_ops_;
+  std::unordered_set<VarHandleBase *> &pending_vars = op_deps_->pending_vars_;
+  std::unordered_set<OpHandleBase *> &ready_ops = op_deps_->ready_ops_;
+  std::unordered_set<VarHandleBase *> ready_vars;
+  // Transform SSAGraph to pending_ops & pending_vars
+  for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
+    for (auto &name_pair : var_map) {
+      for (auto &version_pair : name_pair.second) {
+        InsertPendingVar(&pending_vars, &ready_vars, version_pair);
+      }
+    }
+  }
+  for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) {
+    InsertPendingVar(&pending_vars, &ready_vars, var);
+  }
+  for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
+    if (op->Inputs().empty()) {  // Special case, Op has no input.
+      ready_ops.insert(op);
+    } else {
+      InsertPendingOp(&pending_ops, op);
+    }
+  }
+  for (auto ready_var : ready_vars) {
+    pending_vars.erase(ready_var);
+    for (auto *op : ready_var->PendingOps()) {
+      auto &deps = pending_ops[op];
+      --deps;
+      if (deps == 0) {
+        ready_ops.insert(op);
+      }
+    }
+  }
+}
+void ThreadedSSAGraphExecutor::CopyOpDeps() {
+  op_deps_futures_ = prepare_pool_.enqueue([&] {
+    auto *op_deps = new OpDependentData();
+    op_deps->pending_ops_.insert(op_deps_->pending_ops_.begin(),
+                                 op_deps_->pending_ops_.end());
+    op_deps->pending_vars_.insert(op_deps_->pending_vars_.begin(),
+                                  op_deps_->pending_vars_.end());
+    op_deps->ready_ops_.insert(op_deps_->ready_ops_.begin(),
+                               op_deps_->ready_ops_.end());
+    return std::unique_ptr<OpDependentData>(op_deps);
+  });
+}
 void ThreadedSSAGraphExecutor::RunOp(
@@ -216,7 +259,6 @@ void ThreadedSSAGraphExecutor::RunOp(
       op->Run(strategy_.use_cuda_);
     }
     VLOG(10) << op << " " << op->Name() << " Done ";
-    running_ops_--;
     ready_var_q->Extend(op->Outputs());
     VLOG(10) << op << " " << op->Name() << " Signal posted";
   } catch (...) {
......
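
The executor now builds the op-dependency table once in the constructor (PrepareOpDeps) and hands each Run() a privately copied table prepared off the critical path (CopyOpDeps on prepare_pool_). A minimal sketch of that prepare-once/copy-per-run pattern, using std::async in place of the third-party ThreadPool and invented types (not part of the commit):

#include <future>
#include <memory>
#include <unordered_map>
#include <unordered_set>

struct Deps {  // stand-in for OpDependentData
  std::unordered_map<int, size_t> pending_ops;
  std::unordered_set<int> ready_ops;
};

class Executor {
 public:
  Executor() {
    Prepare();    // build the dependency table once, like PrepareOpDeps()
    CopyAsync();  // start preparing the first run's private copy
  }

  void Run() {
    std::unique_ptr<Deps> deps = copy_future_.get();  // copy made off the hot path
    CopyAsync();                                      // overlap the next copy
    // ... consume *deps: pop ready ops, decrement pending counts, etc. ...
  }

 private:
  void Prepare() {
    master_ = std::make_unique<Deps>();
    master_->pending_ops = {{1, 2}, {2, 1}};
    master_->ready_ops = {0};
  }
  void CopyAsync() {
    copy_future_ = std::async(std::launch::async,
                              [this] { return std::make_unique<Deps>(*master_); });
  }

  std::unique_ptr<Deps> master_;
  std::future<std::unique_ptr<Deps>> copy_future_;
};

int main() {
  Executor exec;
  exec.Run();
  exec.Run();
}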
@@ -15,18 +15,20 @@
 #pragma once
 #include <deque>
+#include <functional>
 #include <list>
+#include <memory>
 #include <string>
+#include <unordered_map>
 #include <unordered_set>
 #include <utility>
 #include <vector>
-#include <functional>
 #include "ThreadPool.h"  // ThreadPool in thrird party
 #include "paddle/fluid/framework/blocking_queue.h"
 #include "paddle/fluid/framework/details/exception_holder.h"
 #include "paddle/fluid/framework/details/execution_strategy.h"
 #include "paddle/fluid/framework/details/fetch_op_handle.h"
+#include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/details/ssa_graph_executor.h"
 #include "paddle/fluid/framework/ir/graph.h"
@@ -36,6 +38,12 @@ class Scope;
 namespace details {
+struct OpDependentData {
+  std::unordered_map<OpHandleBase *, size_t> pending_ops_;
+  std::unordered_set<VarHandleBase *> pending_vars_;
+  std::unordered_set<OpHandleBase *> ready_ops_;
+};
 class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
  public:
  ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
@@ -57,29 +65,35 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
  private:
   ir::Graph *graph_;
   std::unique_ptr<::ThreadPool> pool_;
+  ::ThreadPool prepare_pool_;
   std::vector<Scope *> local_scopes_;
   std::vector<platform::Place> places_;
   platform::DeviceContextPool fetch_ctxs_;
   ExceptionHolder exception_holder_;
-  std::atomic<int> running_ops_;
   void InsertPendingOp(std::unordered_map<OpHandleBase *, size_t> *pending_ops,
                        OpHandleBase *op_instance) const;
   void InsertPendingVar(std::unordered_set<VarHandleBase *> *pending_vars,
-                        BlockingQueue<VarHandleBase *> *ready_vars,
+                        std::unordered_set<VarHandleBase *> *ready_vars,
                         VarHandleBase *var) const;
   void InsertFetchOps(const std::vector<std::string> &fetch_tensors,
                       std::vector<FetchOpHandle *> *fetch_ops,
                       std::unordered_set<VarHandleBase *> *fetch_dependencies,
+                      std::unordered_set<OpHandleBase *> *ready_ops,
                       std::unordered_map<OpHandleBase *, size_t> *pending_ops,
                       std::unordered_set<VarHandleBase *> *pending_vars,
-                      BlockingQueue<VarHandleBase *> *ready_vars,
                       FeedFetchList *fetch_data);
+  void PrepareOpDeps();
+  void CopyOpDeps();
  private:
+  std::future<std::unique_ptr<OpDependentData>> op_deps_futures_;
   ExecutionStrategy strategy_;
+  std::unique_ptr<OpDependentData> op_deps_;
   // use std::list because clear(), push_back, and for_each are O(1)
   std::list<std::future<void>> run_op_futures_;
 };
......
@@ -43,6 +43,7 @@ struct VarHandleBase {
   virtual ~VarHandleBase();
   virtual std::string DebugString() const = 0;
+  virtual const std::string& Name() const = 0;
   void AddInput(OpHandleBase* in, ir::Node* node) {
     node_->inputs.clear();
@@ -95,8 +96,6 @@ struct VarHandleBase {
 //
 // NOTE: runtime variables have place.
 struct VarHandle : public VarHandleBase {
-  explicit VarHandle(ir::Node* node) : VarHandleBase(node) {}
   virtual ~VarHandle();
   std::string DebugString() const override;
@@ -109,6 +108,20 @@ struct VarHandle : public VarHandleBase {
         name_(std::move(name)),
         place_(std::move(place)) {}
+#ifdef PADDLE_WITH_CUDA
+  bool HasEvent() { return has_event_; }
+  const cudaEvent_t& GetEvent() {
+    PADDLE_ENFORCE(HasEvent(), "The event is not set.");
+    return event_;
+  }
+  void SetGenerateEvent(const cudaEvent_t& event) {
+    has_event_ = true;
+    event_ = event;
+  }
+#endif
   // version field currently is not used, however, just store the version to
   // debug easily.
  private:
@@ -116,6 +129,11 @@ struct VarHandle : public VarHandleBase {
   size_t scope_idx_;
   std::string name_;
   platform::Place place_;
+#ifdef PADDLE_WITH_CUDA
+  // Only when this event is triggered, var is generated.
+  cudaEvent_t event_;
+  bool has_event_{false};
+#endif
  public:
   bool IsTheSameVar(const VarHandle& o) const {
@@ -125,6 +143,7 @@ struct VarHandle : public VarHandleBase {
   size_t version() const { return version_; }
   size_t scope_idx() const { return scope_idx_; }
+  const std::string& Name() const override { return name_; }
   const std::string& name() const { return name_; }
   const platform::Place& place() const { return place_; }
 };
@@ -136,6 +155,10 @@ struct DummyVarHandle : public VarHandleBase {
   virtual ~DummyVarHandle();
   std::string DebugString() const override;
+ public:
+  const std::string& Name() const override { return name_; }
+  std::string name_{"DummyVar"};
 };
 } // namespace details
......
@@ -70,6 +70,7 @@ pass_library(conv_affine_channel_fuse_pass inference)
 pass_library(transpose_flatten_concat_fuse_pass inference)
 pass_library(identity_scale_op_clean_pass base)
 pass_library(sync_batch_norm_pass base)
+pass_library(runtime_context_cache_pass base)
 # There may be many transpose-flatten structures in a model, and the output of
 # these structures will be used as inputs to the concat Op. This pattern will
......
@@ -224,8 +224,8 @@ std::unique_ptr<ir::Graph> CPUQuantizePass::ApplyImpl(
   PADDLE_ENFORCE(param_scope());
+  QuantizeConv(graph.get(), false /* with_residual_data */);
   QuantizeConv(graph.get(), true /* with_residual_data */);
-  QuantizeConv(graph.get());
   QuantizePool(graph.get());
   return graph;
......
@@ -599,10 +599,19 @@ bool VarLinksToOp(Node *node, const std::string &op_type) {
 bool IsNthInput(Node *var, Node *op, const std::string &argument, size_t nth) {
   PADDLE_ENFORCE(var->IsVar());
   PADDLE_ENFORCE(op->IsOp());
-  if (op->Op()->Input(argument).size() <= nth) return false;
+  if (!HasInput(op, argument) || op->Op()->Input(argument).size() <= nth)
+    return false;
   return var->Name() == op->Op()->Input(argument)[nth];
 }
+bool HasInput(Node *op, const std::string &argument) {
+  PADDLE_ENFORCE(op->IsOp());
+  auto const &names = op->Op()->InputNames();
+  if (std::find(names.begin(), names.end(), argument) == names.end())
+    return false;
+  return true;
+}
 bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) {
   PADDLE_ENFORCE(var->IsVar());
   PADDLE_ENFORCE(op->IsOp());
@@ -1082,8 +1091,15 @@ PDNode *patterns::Conv::operator()() {
 PDNode *patterns::ConvResidual::operator()(bool with_residual_data) {
   auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
-  if (!with_residual_data)
-    conv_op->assert_op_attr("fuse_residual_connection", false);
+  if (!with_residual_data) {
+    conv_op->assert_more([&](Node *x) {
+      auto node_names = x->Op()->InputNames();
+      if (!HasInput(x, "ResidualData") ||
+          x->Op()->Input("ResidualData").size() == 0)
+        return true;
+      return false;
+    });
+  }
   auto input_var = pattern->NewNode(conv_input_repr())
                        ->AsInput()
......
@@ -305,6 +305,9 @@ bool VarLinksFromOp(Node* node, const std::string& op_type);
 // Check whether a var node is a op node's nth input.
 bool IsNthInput(Node* var, Node* op, const std::string& argument, size_t nth);
+// Check whether the op node has input of given name.
+bool HasInput(Node* op, const std::string& argument);
 // Tell whether a var node is a op node's nth output.
 bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth);
......
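
HasInput() lets callers check whether an op declares an argument before touching op->Op()->Input(argument). A hedged usage sketch (assumes the same paddle::framework::ir declarations as above; the helper name is invented and not part of the commit):

// Only read an optional argument such as "ResidualData" after HasInput()
// confirms the op actually declares it.
bool HasNonEmptyInput(Node *op, const std::string &argument) {
  if (!HasInput(op, argument)) return false;  // argument not declared at all
  return !op->Op()->Input(argument).empty();  // declared and non-empty
}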
@@ -14,12 +14,16 @@ limitations under the License. */
 #pragma once
+#include <memory>
 #include "paddle/fluid/framework/ir/pass.h"
 namespace paddle {
 namespace framework {
 namespace ir {
+/*
+ * Specifies which operators should use MKLDNN.
+ */
 class MKLDNNPlacementPass : public Pass {
  protected:
   std::unique_ptr<ir::Graph> ApplyImpl(
......
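
The comment documents the placement pass that is scheduled when MKLDNN is switched on from the C++ inference API. A hedged usage sketch (model path hypothetical; AnalysisConfig::EnableMKLDNN() and SetModel() as exposed by paddle_inference_api.h; not part of the commit):

#include "paddle/fluid/inference/api/paddle_inference_api.h"

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("./mobilenet_model");  // hypothetical model directory
  config.EnableMKLDNN();                 // adds the MKLDNN placement pass to the CPU strategy
  auto predictor = paddle::CreatePaddlePredictor<paddle::AnalysisConfig>(config);
  return 0;
}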
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/runtime_context_cache_pass.h"
#include <memory>
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> RuntimeContextCachePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
VLOG(3) << "Applies Runtime Context Cache strategy.";
for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
n->Op()->SetAttr(kEnableCacheRuntimeContext, true);
}
}
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(runtime_context_cache_pass,
paddle::framework::ir::RuntimeContextCachePass);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class RuntimeContextCachePass : public Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -876,7 +876,22 @@ std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig(
 void OperatorWithKernel::RunImpl(const Scope& scope,
                                  const platform::Place& place) const {
-  RuntimeContext ctx(Inputs(), Outputs(), scope);
+  if (!HasAttr(kEnableCacheRuntimeContext)) {
+    RuntimeContext ctx(Inputs(), Outputs(), scope);
+    RunImpl(scope, place, &ctx);
+  } else {
+    const Scope* cur_scope = &scope;
+    if (!runtime_ctx_ || pre_scope_ != cur_scope) {
+      runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope));
+      pre_scope_ = cur_scope;
+    }
+    RunImpl(scope, place, runtime_ctx_.get());
+  }
+}
+void OperatorWithKernel::RunImpl(const Scope& scope,
+                                 const platform::Place& place,
+                                 RuntimeContext* runtime_ctx) const {
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   auto* dev_ctx = pool.Get(place);
@@ -891,7 +906,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   OpKernelMap& kernels = kernels_iter->second;
   auto expected_kernel_key = this->GetExpectedKernelType(
-      ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr));
+      ExecutionContext(*this, scope, *dev_ctx, *runtime_ctx, nullptr));
   VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
   auto kernel_iter = kernels.find(expected_kernel_key);
@@ -915,8 +930,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   // do data transformScope &transfer_scope;
   std::vector<std::string> transfered_inplace_vars;
-  auto* transfer_scope =
-      PrepareData(scope, expected_kernel_key, &transfered_inplace_vars, &ctx);
+  auto* transfer_scope = PrepareData(scope, expected_kernel_key,
+                                     &transfered_inplace_vars, runtime_ctx);
   // exec scope is the scope that kernel actually executed on.
   const Scope& exec_scope =
@@ -927,13 +942,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   }
   if (!HasAttr(kAllKernelsMustComputeRuntimeShape)) {
-    RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx);
+    RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, *runtime_ctx);
     this->InferShape(&infer_shape_ctx);
   }
   // TODO(panyx0718): ExecutionContext should only depend on RuntimeContext
   // not Scope. Imperative mode only pass inputs and get outputs.
-  kernel_iter->second(
-      ExecutionContext(*this, exec_scope, *dev_ctx, ctx, kernel_configs));
+  kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx,
                                        *runtime_ctx, kernel_configs));
   if (!transfered_inplace_vars.empty()) {
     // there is inplace variable has been transfered.
......
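
The new RunImpl() overload makes the caching rule explicit: keep the RuntimeContext while the op keeps running against the same Scope, and rebuild it when the scope pointer changes. A standalone restatement with stub types (these are not the framework classes; the sketch is not part of the commit):

#include <memory>

struct Scope {};
struct RuntimeContext { explicit RuntimeContext(const Scope &) {} };

// Minimal restatement of the caching rule used above: rebuild the context only
// when the op runs against a different Scope than last time.
class CachedOp {
 public:
  void Run(const Scope &scope) {
    if (!runtime_ctx_ || pre_scope_ != &scope) {
      runtime_ctx_.reset(new RuntimeContext(scope));
      pre_scope_ = &scope;
    }
    // ... use *runtime_ctx_ for InferShape / kernel launch ...
  }

 private:
  std::unique_ptr<RuntimeContext> runtime_ctx_;
  const Scope *pre_scope_ = nullptr;
};

int main() {
  Scope s;
  CachedOp op;
  op.Run(s);  // builds the context
  op.Run(s);  // reuses it: same scope pointer
}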
@@ -62,6 +62,14 @@ constexpr char kZeroVarSuffix[] = "@ZERO";
 /// Variables with this suffix are the new Gradient.
 constexpr char kNewGradSuffix[] = "@NEWGRAD@";
+/// RuntimeContext is used to relate input/output names of Operator with
+/// the corresponding variables in name scope.
+/// If an Op has attribute kEnableCacheRuntimeContext, it means that in a same
+/// name scope, since the input/output names of this Op do not change in the
+/// execution, RuntimeContext could be created only at the first iteration of
+/// this Op's execution to save the elapsed time.
+constexpr char kEnableCacheRuntimeContext[] = "@ENABLE_CACHE_RUNTIME_CONTEXT@";
 /// If an Op has this attribute, all its kernels should calculate output
 /// variable's shape in the corresponding Compute() function. And
 /// OperatorWithKernel::RunImpl() would skip call this Op's InferShape()
@@ -456,6 +464,8 @@ class OperatorWithKernel : public OperatorBase {
   // same.
   proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const;
   void RunImpl(const Scope& scope, const platform::Place& place) const final;
+  void RunImpl(const Scope& scope, const platform::Place& place,
+               RuntimeContext* runtime_ctx) const;
   /**
    * Transfer data from scope to a transfered scope. If there is no data need to
@@ -474,6 +484,8 @@ class OperatorWithKernel : public OperatorBase {
  protected:
   mutable OpKernelConfigsMap kernel_configs_map_;
+  mutable std::unique_ptr<RuntimeContext> runtime_ctx_;
+  mutable const Scope* pre_scope_ = nullptr;
 };
 extern bool OpSupportGPU(const std::string& op_type);
......
@@ -131,6 +131,15 @@ struct Argument {
   // Pass a set of op types to enable its mkldnn kernel
   DECL_ARGUMENT_FIELD(mkldnn_enabled_op_types, MKLDNNEnabledOpTypes,
                       std::unordered_set<std::string>);
+  // A set of op types to enable their quantized kernels
+  DECL_ARGUMENT_FIELD(quantize_enabled_op_types, QuantizeEnabledOpTypes,
+                      std::unordered_set<std::string>);
+  // A set of op IDs to exclude from enabling their quantized kernels
+  DECL_ARGUMENT_FIELD(quantize_excluded_op_ids, QuantizeExcludedOpIds,
+                      std::unordered_set<int>);
   // Scales for variables to be quantized
   DECL_ARGUMENT_FIELD(quant_var_scales, QuantVarScales, VarQuantScale);
......
@@ -15,6 +15,7 @@
 #include "paddle/fluid/inference/analysis/ir_pass_manager.h"
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
 #include <vector>
 #include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/ir/graph.h"
@@ -60,6 +61,13 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("mkldnn_enabled_op_types",
                 new std::unordered_set<std::string>(
                     argument->mkldnn_enabled_op_types()));
+    } else if (pass_name == "cpu_quantize_placement_pass") {
+      pass->Set("quantize_enabled_op_types",
+                new std::unordered_set<std::string>(
+                    argument->quantize_enabled_op_types()));
+      pass->Set(
+          "quantize_excluded_op_ids",
+          new std::unordered_set<int>(argument->quantize_excluded_op_ids()));
     } else if (pass_name == "cpu_quantize_pass") {
       pass->Set("quant_var_scales",
                 new VarQuantScale(argument->quant_var_scales()));
......
...@@ -202,6 +202,7 @@ void AnalysisConfig::Update() { ...@@ -202,6 +202,7 @@ void AnalysisConfig::Update() {
// Append after the Affine_channel_conv_fuse pass. // Append after the Affine_channel_conv_fuse pass.
pass_builder()->InsertPass(3, "tensorrt_subgraph_pass"); pass_builder()->InsertPass(3, "tensorrt_subgraph_pass");
} }
pass_builder()->DeletePass("runtime_context_cache_pass");
} }
if (use_mkldnn_) { if (use_mkldnn_) {
......
...@@ -80,6 +80,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { ...@@ -80,6 +80,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"conv_elementwise_add_act_fuse_pass", // "conv_elementwise_add_act_fuse_pass", //
"conv_elementwise_add2_act_fuse_pass", // "conv_elementwise_add2_act_fuse_pass", //
"conv_elementwise_add_fuse_pass", // "conv_elementwise_add_fuse_pass", //
"runtime_context_cache_pass", //
#endif #endif
}); });
...@@ -90,6 +91,10 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { ...@@ -90,6 +91,10 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
use_gpu_ = true; use_gpu_ = true;
} }
void GpuPassStrategy::EnableQuantizer() {
LOG(ERROR) << "GPU does not support quantization yet";
}
void PaddlePassBuilder::AppendAnalysisPass(const std::string &pass) { void PaddlePassBuilder::AppendAnalysisPass(const std::string &pass) {
analysis_passes_.push_back(pass); analysis_passes_.push_back(pass);
} }
...@@ -115,6 +120,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) { ...@@ -115,6 +120,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
"conv_eltwiseadd_bn_fuse_pass", // "conv_eltwiseadd_bn_fuse_pass", //
"is_test_pass", // "is_test_pass", //
"identity_scale_op_clean_pass", // "identity_scale_op_clean_pass", //
"runtime_context_cache_pass", //
}); });
use_gpu_ = false; use_gpu_ = false;
} }
......
...@@ -84,6 +84,10 @@ class PassStrategy : public PaddlePassBuilder { ...@@ -84,6 +84,10 @@ class PassStrategy : public PaddlePassBuilder {
*/ */
virtual void EnableMKLDNN() {} virtual void EnableMKLDNN() {}
/** Enable quantize optimization
*/
virtual void EnableQuantizer() {}
bool use_gpu() const { return use_gpu_; } bool use_gpu() const { return use_gpu_; }
virtual ~PassStrategy() = default; virtual ~PassStrategy() = default;
...@@ -124,6 +128,16 @@ class CpuPassStrategy : public PassStrategy { ...@@ -124,6 +128,16 @@ class CpuPassStrategy : public PassStrategy {
use_mkldnn_ = false; use_mkldnn_ = false;
#endif #endif
} }
void EnableQuantizer() override {
if (!use_quantizer_) {
passes_.push_back("cpu_quantize_placement_pass");
}
use_quantizer_ = true;
}
protected:
bool use_quantizer_{false};
}; };
/** The GPU passes strategy, it is used in AnalysisPredictor with GPU mode. /** The GPU passes strategy, it is used in AnalysisPredictor with GPU mode.
...@@ -138,6 +152,7 @@ class GpuPassStrategy : public PassStrategy { ...@@ -138,6 +152,7 @@ class GpuPassStrategy : public PassStrategy {
} }
void EnableMKLDNN() override; void EnableMKLDNN() override;
void EnableQuantizer() override;
virtual ~GpuPassStrategy() = default; virtual ~GpuPassStrategy() = default;
}; };
......
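A standalone sketch of the use_quantizer_ latch introduced by CpuPassStrategy::EnableQuantizer() above; the class below only mirrors the diff and is not the Paddle type. It shows that repeated calls append cpu_quantize_placement_pass only once.

#include <cstdio>
#include <string>
#include <vector>

// Mirrors the latch from the diff; not the Paddle class itself.
class CpuPassStrategySketch {
 public:
  void EnableQuantizer() {
    if (!use_quantizer_) {
      passes_.push_back("cpu_quantize_placement_pass");
    }
    use_quantizer_ = true;
  }
  size_t NumPasses() const { return passes_.size(); }

 private:
  std::vector<std::string> passes_;
  bool use_quantizer_{false};
};

int main() {
  CpuPassStrategySketch strategy;
  strategy.EnableQuantizer();
  strategy.EnableQuantizer();                  // second call is a no-op
  std::printf("%zu\n", strategy.NumPasses());  // 1
  return 0;
}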
cc_library(benchmark SRCS benchmark.cc DEPS enforce) cc_library(benchmark SRCS benchmark.cc DEPS enforce)
cc_test(test_benchmark SRCS benchmark_tester.cc DEPS benchmark) cc_test(test_benchmark SRCS benchmark_tester.cc DEPS benchmark)
cc_binary(visualizer SRCS visualizer.cc DEPS analysis
paddle_pass_builder ir_pass_manager pass graph_viz_pass analysis_passes)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/utils/visualizer.h"
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <fstream>
#include <memory>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/passes/ir_analysis_pass.h"
#include "paddle/fluid/platform/init.h"
DEFINE_string(model_dir, "", "model directory");
DEFINE_string(model_program_path, "", "model program path");
DEFINE_string(model_params_path, "", "model params path");
using paddle::inference::analysis::Argument;
namespace paddle {
namespace inference {
namespace utils {
void Visualizer::SetArgument(Argument *argument) { argument_ = argument; }
bool Visualizer::Run() {
paddle::framework::InitDevices(false);
paddle::inference::analysis::Analyzer().Run(argument_);
return true;
}
} // namespace utils
} // namespace inference
} // namespace paddle
// Generate a dot file describing the structure of graph.
// To use this tool, run command: ./visualizer [options...]
// Options:
// --model_dir: the directory of model
// --model_program_path: the path of program
// --model_params_path: the path of params
int main(int argc, char *argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
paddle::inference::analysis::Argument argument;
argument.SetUseGPU(false);
argument.SetUseTensorRT(false);
if (FLAGS_model_dir.empty()) {
if (FLAGS_model_program_path.empty() || FLAGS_model_params_path.empty()) {
LOG(ERROR) << "Please set model_dir"
" or model_program_path and model_params_path";
return -1;
} else {
argument.SetModelProgramPath(FLAGS_model_program_path);
argument.SetModelParamsPath(FLAGS_model_params_path);
}
} else {
argument.SetModelDir(FLAGS_model_dir);
}
// Only 1 pass, default filename is 0_ir_origin.dot
// For more details, see paddle::inference::analysis::IRPassManager
argument.SetIrAnalysisPasses({"infer_clean_graph_pass", "graph_viz_pass"});
std::unique_ptr<paddle::framework::Scope> scope{
new paddle::framework::Scope()};
argument.SetScopeNotOwned(
const_cast<paddle::framework::Scope *>(scope.get()));
paddle::inference::utils::Visualizer visualizer;
visualizer.SetArgument(&argument);
visualizer.Run();
return 0;
}
USE_PASS(infer_clean_graph_pass);
USE_PASS(graph_viz_pass);
USE_PASS(graph_to_program_pass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/inference/analysis/argument.h"
namespace paddle {
namespace inference {
namespace utils {
using paddle::inference::analysis::Argument;
class Visualizer final {
public:
Visualizer() = default;
~Visualizer() = default;
Visualizer(const Visualizer &) = delete;
Visualizer &operator=(const Visualizer &) = delete;
void SetArgument(Argument *);
bool Run();
private:
Argument *argument_;
};
} // namespace utils
} // namespace inference
} // namespace paddle
...@@ -61,4 +61,6 @@ nv_test(allocation_and_eigen_test SRCS allocation_and_eigen_test.cu DEPS allocat ...@@ -61,4 +61,6 @@ nv_test(allocation_and_eigen_test SRCS allocation_and_eigen_test.cu DEPS allocat
cc_test(retry_allocator_test SRCS retry_allocator_test.cc DEPS retry_allocator best_fit_allocator locked_allocator cpu_allocator) cc_test(retry_allocator_test SRCS retry_allocator_test.cc DEPS retry_allocator best_fit_allocator locked_allocator cpu_allocator)
cc_test(allocator_facade_test SRCS allocator_facade_test.cc DEPS allocator_facade) cc_test(allocator_facade_abs_flags_test SRCS allocator_facade_abs_flags_test.cc DEPS allocator_facade)
cc_test(allocator_facade_frac_flags_test SRCS allocator_facade_frac_flags_test.cc DEPS allocator_facade)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#ifdef PADDLE_WITH_CUDA
DECLARE_double(fraction_of_gpu_memory_to_use);
DECLARE_double(fraction_of_cuda_pinned_memory_to_use);
DECLARE_uint64(initial_gpu_memory_in_mb);
DECLARE_uint64(reallocate_gpu_memory_in_mb);
DECLARE_int64(gpu_allocator_retry_time);
#endif
namespace paddle {
namespace memory {
namespace allocation {
//! Run allocate test cases for different places
void AllocateTestCases() {
auto &instance = AllocatorFacade::Instance();
platform::Place place;
size_t size = 1024;
{
place = platform::CPUPlace();
size = 1024;
auto cpu_allocation = instance.Alloc(place, size);
ASSERT_NE(cpu_allocation, nullptr);
ASSERT_NE(cpu_allocation->ptr(), nullptr);
ASSERT_EQ(cpu_allocation->place(), place);
ASSERT_EQ(cpu_allocation->size(), size);
}
#ifdef PADDLE_WITH_CUDA
{
place = platform::CUDAPlace(0);
size = 1024;
auto gpu_allocation = instance.Alloc(place, size);
ASSERT_NE(gpu_allocation, nullptr);
ASSERT_NE(gpu_allocation->ptr(), nullptr);
ASSERT_EQ(gpu_allocation->place(), place);
ASSERT_GE(gpu_allocation->size(), size);
}
{
// Allocate 2GB gpu memory
place = platform::CUDAPlace(0);
size = 2 * static_cast<size_t>(1 << 30);
auto gpu_allocation = instance.Alloc(place, size);
ASSERT_NE(gpu_allocation, nullptr);
ASSERT_NE(gpu_allocation->ptr(), nullptr);
ASSERT_EQ(gpu_allocation->place(), place);
ASSERT_GE(gpu_allocation->size(), size);
}
{
place = platform::CUDAPinnedPlace();
size = (1 << 20);
auto cuda_pinned_allocation =
instance.Alloc(platform::CUDAPinnedPlace(), 1 << 20);
ASSERT_NE(cuda_pinned_allocation, nullptr);
ASSERT_NE(cuda_pinned_allocation->ptr(), nullptr);
ASSERT_EQ(cuda_pinned_allocation->place(), place);
ASSERT_GE(cuda_pinned_allocation->size(), size);
}
#endif
}
TEST(Allocator, SpecifyGpuMemory) {
#ifdef PADDLE_WITH_CUDA
// Set to 0.0 to test FLAGS_initial_gpu_memory_in_mb and
// FLAGS_reallocate_gpu_memory_in_mb
FLAGS_fraction_of_gpu_memory_to_use = 0.0;
// 512 MB
FLAGS_initial_gpu_memory_in_mb = 512;
// 4 MB
FLAGS_reallocate_gpu_memory_in_mb = 4;
FLAGS_gpu_allocator_retry_time = 500;
FLAGS_fraction_of_cuda_pinned_memory_to_use = 0.5;
#endif
AllocateTestCases();
}
} // namespace allocation
} // namespace memory
} // namespace paddle
...@@ -19,6 +19,8 @@ ...@@ -19,6 +19,8 @@
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
DECLARE_double(fraction_of_gpu_memory_to_use); DECLARE_double(fraction_of_gpu_memory_to_use);
DECLARE_double(fraction_of_cuda_pinned_memory_to_use); DECLARE_double(fraction_of_cuda_pinned_memory_to_use);
DECLARE_uint64(initial_gpu_memory_in_mb);
DECLARE_uint64(reallocate_gpu_memory_in_mb);
DECLARE_int64(gpu_allocator_retry_time); DECLARE_int64(gpu_allocator_retry_time);
#endif #endif
...@@ -26,13 +28,8 @@ namespace paddle { ...@@ -26,13 +28,8 @@ namespace paddle {
namespace memory { namespace memory {
namespace allocation { namespace allocation {
TEST(allocator, allocator) { //! Run allocate test cases for different places
#ifdef PADDLE_WITH_CUDA void AllocateTestCases() {
FLAGS_fraction_of_gpu_memory_to_use = 0.01;
FLAGS_gpu_allocator_retry_time = 500;
FLAGS_fraction_of_cuda_pinned_memory_to_use = 0.5;
#endif
auto &instance = AllocatorFacade::Instance(); auto &instance = AllocatorFacade::Instance();
platform::Place place; platform::Place place;
size_t size = 1024; size_t size = 1024;
...@@ -82,6 +79,16 @@ TEST(allocator, allocator) { ...@@ -82,6 +79,16 @@ TEST(allocator, allocator) {
#endif #endif
} }
TEST(Allocator, Allocator) {
#ifdef PADDLE_WITH_CUDA
FLAGS_fraction_of_gpu_memory_to_use = 0.01;
FLAGS_gpu_allocator_retry_time = 500;
FLAGS_fraction_of_cuda_pinned_memory_to_use = 0.5;
#endif
AllocateTestCases();
}
} // namespace allocation } // namespace allocation
} // namespace memory } // namespace memory
} // namespace paddle } // namespace paddle
...@@ -37,6 +37,8 @@ DEFINE_bool(init_allocated_mem, false, ...@@ -37,6 +37,8 @@ DEFINE_bool(init_allocated_mem, false,
"that initializing the allocated memory with a small value " "that initializing the allocated memory with a small value "
"during unit testing."); "during unit testing.");
DECLARE_double(fraction_of_gpu_memory_to_use); DECLARE_double(fraction_of_gpu_memory_to_use);
DECLARE_uint64(initial_gpu_memory_in_mb);
DECLARE_uint64(reallocate_gpu_memory_in_mb);
DECLARE_bool(benchmark); DECLARE_bool(benchmark);
namespace paddle { namespace paddle {
...@@ -153,12 +155,18 @@ BuddyAllocator *GetGPUBuddyAllocator(int gpu_id) { ...@@ -153,12 +155,18 @@ BuddyAllocator *GetGPUBuddyAllocator(int gpu_id) {
platform::GpuMinChunkSize(), platform::GpuMinChunkSize(),
platform::GpuMaxChunkSize()); platform::GpuMaxChunkSize());
VLOG(10) << "\n\nNOTE: each GPU device use " VLOG(10) << "\n\nNOTE:\n"
<< FLAGS_fraction_of_gpu_memory_to_use * 100 << "You can set GFlags environment variable "
<< "% of GPU memory.\n" << "'FLAGS_fraction_of_gpu_memory_to_use' "
<< "You can set GFlags environment variable '" << "or 'FLAGS_initial_gpu_memory_in_mb' "
<< "FLAGS_fraction_of_gpu_memory_to_use" << "or 'FLAGS_reallocate_gpu_memory_in_mb' "
<< "' to change the fraction of GPU usage.\n\n"; << "to change the memory size for GPU usage.\n"
<< "Current 'FLAGS_fraction_of_gpu_memory_to_use' value is "
<< FLAGS_fraction_of_gpu_memory_to_use
<< ". Current 'FLAGS_initial_gpu_memory_in_mb' value is "
<< FLAGS_initial_gpu_memory_in_mb
<< ". Current 'FLAGS_reallocate_gpu_memory_in_mb' value is "
<< FLAGS_reallocate_gpu_memory_in_mb << "\n\n";
} }
}); });
......
...@@ -9,3 +9,5 @@ endif(${WITH_GPU}) ...@@ -9,3 +9,5 @@ endif(${WITH_GPU})
cc_test(system_allocator_test SRCS system_allocator_test.cc DEPS system_allocator) cc_test(system_allocator_test SRCS system_allocator_test.cc DEPS system_allocator)
cc_library(buddy_allocator SRCS buddy_allocator.cc DEPS memory_block system_allocator glog) cc_library(buddy_allocator SRCS buddy_allocator.cc DEPS memory_block system_allocator glog)
cc_test(buddy_allocator_test SRCS buddy_allocator_test.cc DEPS buddy_allocator)
...@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/memory/detail/buddy_allocator.h" #include "paddle/fluid/memory/detail/buddy_allocator.h"
#include <algorithm>
#include <utility>
#include "glog/logging.h" #include "glog/logging.h"
DEFINE_bool(free_idle_memory, false, DEFINE_bool(free_idle_memory, false,
...@@ -36,9 +40,10 @@ BuddyAllocator::~BuddyAllocator() { ...@@ -36,9 +40,10 @@ BuddyAllocator::~BuddyAllocator() {
"have actually been freed"; "have actually been freed";
while (!pool_.empty()) { while (!pool_.empty()) {
auto block = static_cast<MemoryBlock*>(std::get<2>(*pool_.begin())); auto block = static_cast<MemoryBlock*>(std::get<2>(*pool_.begin()));
VLOG(10) << "Free from block (" << block << ", " << max_chunk_size_ << ")"; VLOG(10) << "Free from block (" << block << ", " << block->size(cache_)
<< ")";
system_allocator_->Free(block, max_chunk_size_, block->index(cache_)); system_allocator_->Free(block, block->size(cache_), block->index(cache_));
cache_.invalidate(block); cache_.invalidate(block);
pool_.erase(pool_.begin()); pool_.erase(pool_.begin());
} }
...@@ -71,7 +76,7 @@ void* BuddyAllocator::Alloc(size_t unaligned_size) { ...@@ -71,7 +76,7 @@ void* BuddyAllocator::Alloc(size_t unaligned_size) {
// refill the pool if failure // refill the pool if failure
if (it == pool_.end()) { if (it == pool_.end()) {
it = RefillPool(); it = RefillPool(size);
// if still failure, fail fatally // if still failure, fail fatally
if (it == pool_.end()) { if (it == pool_.end()) {
return nullptr; return nullptr;
...@@ -184,19 +189,28 @@ void* BuddyAllocator::SystemAlloc(size_t size) { ...@@ -184,19 +189,28 @@ void* BuddyAllocator::SystemAlloc(size_t size) {
return static_cast<MemoryBlock*>(p)->data(); return static_cast<MemoryBlock*>(p)->data();
} }
BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() { BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool(
size_t request_bytes) {
size_t allocate_bytes = max_chunk_size_;
size_t index = 0;
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (system_allocator_->UseGpu()) { if (system_allocator_->UseGpu()) {
if ((total_used_ + total_free_) == 0) { if ((total_used_ + total_free_) == 0) {
// Compute the maximum allocation size for the first allocation. // Compute the allocation size for gpu for the first allocation.
max_chunk_size_ = platform::GpuMaxChunkSize(); allocate_bytes = std::max(platform::GpuInitAllocSize(), request_bytes);
} else {
// Reallocation size
if (realloc_size_ == 0) {
realloc_size_ = platform::GpuReallocSize();
}
allocate_bytes = std::max(realloc_size_, request_bytes);
} }
} }
#endif #endif
// Allocate a new maximum sized block // Allocate a new block
size_t index = 0; void* p = system_allocator_->Alloc(&index, allocate_bytes);
void* p = system_allocator_->Alloc(&index, max_chunk_size_);
if (p == nullptr) return pool_.end(); if (p == nullptr) return pool_.end();
...@@ -204,7 +218,7 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() { ...@@ -204,7 +218,7 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() {
<< " from system allocator"; << " from system allocator";
static_cast<MemoryBlock*>(p)->init(&cache_, MemoryBlock::FREE_CHUNK, index, static_cast<MemoryBlock*>(p)->init(&cache_, MemoryBlock::FREE_CHUNK, index,
max_chunk_size_, nullptr, nullptr); allocate_bytes, nullptr, nullptr);
// gpu fallback allocation // gpu fallback allocation
if (system_allocator_->UseGpu() && if (system_allocator_->UseGpu() &&
...@@ -212,10 +226,10 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() { ...@@ -212,10 +226,10 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() {
fallback_alloc_count_++; fallback_alloc_count_++;
} }
total_free_ += max_chunk_size_; total_free_ += allocate_bytes;
// dump the block into pool // dump the block into pool
return pool_.insert(IndexSizeAddress(index, max_chunk_size_, p)).first; return pool_.insert(IndexSizeAddress(index, allocate_bytes, p)).first;
} }
BuddyAllocator::PoolSet::iterator BuddyAllocator::FindExistChunk(size_t size) { BuddyAllocator::PoolSet::iterator BuddyAllocator::FindExistChunk(size_t size) {
...@@ -286,12 +300,12 @@ void BuddyAllocator::CleanIdleFallBackAlloc() { ...@@ -286,12 +300,12 @@ void BuddyAllocator::CleanIdleFallBackAlloc() {
VLOG(10) << "Return block " << block << " to fallback allocator."; VLOG(10) << "Return block " << block << " to fallback allocator.";
system_allocator_->Free(block, max_chunk_size_, block->index(cache_)); system_allocator_->Free(block, block->size(cache_), block->index(cache_));
cache_.invalidate(block); cache_.invalidate(block);
pool = PoolSet::reverse_iterator(pool_.erase(std::next(pool).base())); pool = PoolSet::reverse_iterator(pool_.erase(std::next(pool).base()));
total_free_ -= max_chunk_size_; total_free_ -= block->size(cache_);
fallback_alloc_count_--; fallback_alloc_count_--;
// If no fallback allocation exists, return directly // If no fallback allocation exists, return directly
...@@ -322,12 +336,12 @@ void BuddyAllocator::CleanIdleNormalAlloc() { ...@@ -322,12 +336,12 @@ void BuddyAllocator::CleanIdleNormalAlloc() {
VLOG(10) << "Return block " << block << " to base allocator."; VLOG(10) << "Return block " << block << " to base allocator.";
system_allocator_->Free(block, max_chunk_size_, block->index(cache_)); system_allocator_->Free(block, block->size(cache_), block->index(cache_));
cache_.invalidate(block); cache_.invalidate(block);
pool = PoolSet::reverse_iterator(pool_.erase(std::next(pool).base())); pool = PoolSet::reverse_iterator(pool_.erase(std::next(pool).base()));
total_free_ -= max_chunk_size_; total_free_ -= block->size(cache_);
if (!shall_free_alloc()) return; if (!shall_free_alloc()) return;
} }
......
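A minimal sketch of the size-selection rule RefillPool(request_bytes) now follows: the first refill uses the configured initial size, later refills use the reallocation size, and either is raised to the requested bytes if that is larger. The constants below stand in for platform::GpuInitAllocSize() / GpuReallocSize() and assume FLAGS_initial_gpu_memory_in_mb = 100 and FLAGS_reallocate_gpu_memory_in_mb = 50 purely for illustration.

#include <algorithm>
#include <cstddef>
#include <cstdio>

// Assumed stand-ins for platform::GpuInitAllocSize() / GpuReallocSize().
const size_t kInitialBytes = 100ull << 20;  // 100 MB
const size_t kReallocBytes = 50ull << 20;   // 50 MB

size_t ChooseRefillBytes(size_t bytes_already_in_pool, size_t request_bytes) {
  const size_t base =
      (bytes_already_in_pool == 0) ? kInitialBytes : kReallocBytes;
  return std::max(base, request_bytes);  // never refill less than the request
}

int main() {
  std::printf("%zu MB\n", ChooseRefillBytes(0, 10ull << 20) >> 20);               // 100
  std::printf("%zu MB\n", ChooseRefillBytes(100ull << 20, 10ull << 20) >> 20);    // 50
  std::printf("%zu MB\n", ChooseRefillBytes(100ull << 20, 2048ull << 20) >> 20);  // 2048
  return 0;
}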
...@@ -60,7 +60,7 @@ class BuddyAllocator { ...@@ -60,7 +60,7 @@ class BuddyAllocator {
void* SystemAlloc(size_t size); void* SystemAlloc(size_t size);
/*! \brief If existing chunks are not suitable, refill pool */ /*! \brief If existing chunks are not suitable, refill pool */
PoolSet::iterator RefillPool(); PoolSet::iterator RefillPool(size_t request_bytes);
/** /**
* \brief Find the suitable chunk from existing pool and split * \brief Find the suitable chunk from existing pool and split
...@@ -89,6 +89,8 @@ class BuddyAllocator { ...@@ -89,6 +89,8 @@ class BuddyAllocator {
size_t min_chunk_size_; // the minimum size of each chunk size_t min_chunk_size_; // the minimum size of each chunk
size_t max_chunk_size_; // the maximum size of each chunk size_t max_chunk_size_; // the maximum size of each chunk
size_t realloc_size_ = 0; // the size of re-allocated chunk
private: private:
/** /**
* \brief A list of free allocation * \brief A list of free allocation
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/memory/detail/buddy_allocator.h"
#include <memory>
#include "gflags/gflags.h"
#include "gtest/gtest.h"
#include "paddle/fluid/memory/detail/system_allocator.h"
#include "paddle/fluid/platform/gpu_info.h"
#ifdef PADDLE_WITH_CUDA
DECLARE_double(fraction_of_gpu_memory_to_use);
DECLARE_uint64(initial_gpu_memory_in_mb);
DECLARE_uint64(reallocate_gpu_memory_in_mb);
#endif
namespace paddle {
namespace memory {
namespace detail {
constexpr static int test_gpu_id = 0;
void TestBuddyAllocator(BuddyAllocator* allocator, size_t size_bytes) {
bool freed = false;
size_t used_bytes = allocator->Used();
if (size_bytes > 0) {
void* p = allocator->Alloc(size_bytes);
EXPECT_NE(p, nullptr);
#ifdef PADDLE_WITH_CUDA
if (size_bytes < platform::GpuMaxChunkSize()) {
#else
if (size_bytes < platform::CpuMaxChunkSize()) {
#endif
// Not allocated from SystemAllocator
EXPECT_GE(allocator->Used(), used_bytes + size_bytes);
} else {
// Allocations from SystemAllocator don't count in Used()
EXPECT_EQ(allocator->Used(), used_bytes);
}
int* intp = static_cast<int*>(p);
std::shared_ptr<int> ptr(intp, [&](void* p) {
allocator->Free(intp);
freed = true;
});
} else {
freed = true;
}
EXPECT_EQ(used_bytes, allocator->Used());
EXPECT_TRUE(freed);
}
#ifdef PADDLE_WITH_CUDA
TEST(BuddyAllocator, GpuFraction) {
FLAGS_fraction_of_gpu_memory_to_use = 0.01;
BuddyAllocator buddy_allocator(
std::unique_ptr<SystemAllocator>(new GPUAllocator(test_gpu_id)),
platform::GpuMinChunkSize(), platform::GpuMaxChunkSize());
TestBuddyAllocator(&buddy_allocator, 10);
TestBuddyAllocator(&buddy_allocator, 10 << 10);
TestBuddyAllocator(&buddy_allocator, 10 << 20);
TestBuddyAllocator(&buddy_allocator, 2 * static_cast<size_t>(1 << 30));
}
TEST(BuddyAllocator, InitRealloc) {
FLAGS_initial_gpu_memory_in_mb = 100;
FLAGS_reallocate_gpu_memory_in_mb = 50;
EXPECT_EQ(platform::GpuMaxChunkSize(), static_cast<size_t>(100 << 20));
BuddyAllocator buddy_allocator(
std::unique_ptr<SystemAllocator>(new GPUAllocator(test_gpu_id)),
platform::GpuMinChunkSize(), platform::GpuMaxChunkSize());
// Less than the initial size and the reallocate size
TestBuddyAllocator(&buddy_allocator, 10 << 20);
// Between the initial size and the reallocate size, not exceeding the pool
TestBuddyAllocator(&buddy_allocator, 80 << 20);
// Less than the reallocate size and exceeds the pool
TestBuddyAllocator(&buddy_allocator, 40 << 20);
// Greater than the reallocate size and exceeds the pool
TestBuddyAllocator(&buddy_allocator, 80 << 20);
// Greater than the initial size and the reallocate size
TestBuddyAllocator(&buddy_allocator, 2 * static_cast<size_t>(1 << 30));
}
TEST(BuddyAllocator, ReallocSizeGreaterThanInit) {
FLAGS_initial_gpu_memory_in_mb = 5;
FLAGS_reallocate_gpu_memory_in_mb = 10;
EXPECT_EQ(platform::GpuMaxChunkSize(), static_cast<size_t>(10 << 20));
BuddyAllocator buddy_allocator(
std::unique_ptr<SystemAllocator>(new GPUAllocator(test_gpu_id)),
platform::GpuMinChunkSize(), platform::GpuMaxChunkSize());
// Less than the initial size and the reallocate size
TestBuddyAllocator(&buddy_allocator, 1 << 20);
// Between the initial size and the reallocate size, not exceeding the pool
TestBuddyAllocator(&buddy_allocator, 3 << 20);
// Less than the initial size and exceeds the pool
TestBuddyAllocator(&buddy_allocator, 3 << 20);
// Less than the reallocate size and does not exceed the pool (now the pool
// is 15 MB, 7 MB used)
TestBuddyAllocator(&buddy_allocator, 7 << 20);
// Less than the reallocate size and exceeds the pool
TestBuddyAllocator(&buddy_allocator, 8 << 20);
// Greater than the initial size and the reallocate size
TestBuddyAllocator(&buddy_allocator, 2 * static_cast<size_t>(1 << 30));
}
#endif
} // namespace detail
} // namespace memory
} // namespace paddle
...@@ -32,6 +32,9 @@ limitations under the License. */ ...@@ -32,6 +32,9 @@ limitations under the License. */
DECLARE_bool(use_pinned_memory); DECLARE_bool(use_pinned_memory);
DECLARE_double(fraction_of_gpu_memory_to_use); DECLARE_double(fraction_of_gpu_memory_to_use);
DECLARE_uint64(initial_gpu_memory_in_mb);
DECLARE_uint64(reallocate_gpu_memory_in_mb);
namespace paddle { namespace paddle {
namespace memory { namespace memory {
namespace detail { namespace detail {
...@@ -119,11 +122,18 @@ void* GPUAllocator::Alloc(size_t* index, size_t size) { ...@@ -119,11 +122,18 @@ void* GPUAllocator::Alloc(size_t* index, size_t size) {
gpu_alloc_size_ += size; gpu_alloc_size_ += size;
return p; return p;
} else { } else {
LOG(WARNING) LOG(WARNING) << "Cannot malloc " << size / 1024.0 / 1024.0
<< "Cannot malloc " << size / 1024.0 / 1024.0 << " MB GPU memory. Please shrink "
<< " MB GPU memory. Please shrink FLAGS_fraction_of_gpu_memory_to_use " "FLAGS_fraction_of_gpu_memory_to_use or "
"environment variable to a lower value. Current value is " "FLAGS_initial_gpu_memory_in_mb or "
<< FLAGS_fraction_of_gpu_memory_to_use; "FLAGS_reallocate_gpu_memory_in_mb"
"environment variable to a lower value. "
<< "Current FLAGS_fraction_of_gpu_memory_to_use value is "
<< FLAGS_fraction_of_gpu_memory_to_use
<< ". Current FLAGS_initial_gpu_memory_in_mb value is "
<< FLAGS_initial_gpu_memory_in_mb
<< ". Current FLAGS_reallocate_gpu_memory_in_mb value is "
<< FLAGS_reallocate_gpu_memory_in_mb;
return nullptr; return nullptr;
} }
} }
......
...@@ -76,12 +76,16 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx, ...@@ -76,12 +76,16 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
const std::string& name) { const std::string& name) {
framework::LibraryType library{framework::LibraryType::kPlain}; framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout = framework::DataLayout::kAnyLayout; framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_CUDA // FIXME(liuwei1031) temporarily disable the code to unblock users
auto it1 = oper.Attrs().find("use_cudnn"); // TODO(liuwei1031) figure out the reason behind
if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) { // https://github.com/PaddlePaddle/Paddle/issues/16096
library = framework::LibraryType::kCUDNN; // and re-enable this in the future
} // #ifdef PADDLE_WITH_CUDA
#endif // auto it1 = oper.Attrs().find("use_cudnn");
// if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
// library = framework::LibraryType::kCUDNN;
// }
// #endif
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
auto it = oper.Attrs().find("use_mkldnn"); auto it = oper.Attrs().find("use_mkldnn");
if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() && if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
...@@ -188,6 +192,9 @@ $$out = x - \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$ ...@@ -188,6 +192,9 @@ $$out = x - \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
UNUSED constexpr char SqrtDoc[] = R"DOC( UNUSED constexpr char SqrtDoc[] = R"DOC(
Sqrt Activation Operator. Sqrt Activation Operator.
Please make sure the input is legal. When the input is a negative value close to zero,
add a small epsilon (1e-12) to avoid a negative result caused by numerical error.
$out = \sqrt{x}$ $out = \sqrt{x}$
)DOC"; )DOC";
......
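A tiny example of the epsilon guard suggested in the Sqrt documentation above; the value 1e-12 follows the doc, and the input here is only an illustration of a value pushed slightly below zero by rounding error.

#include <cmath>
#include <cstdio>

int main() {
  const double eps = 1e-12;  // the epsilon suggested by the doc
  const double x = -1e-15;   // a value pushed just below zero by rounding error
  std::printf("%g\n", std::sqrt(x + eps));  // well-defined instead of NaN
  return 0;
}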
...@@ -67,6 +67,22 @@ class AffineChannelOp : public framework::OperatorWithKernel { ...@@ -67,6 +67,22 @@ class AffineChannelOp : public framework::OperatorWithKernel {
"Input(Bias) of AffineChannelOp should not be null."); "Input(Bias) of AffineChannelOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of AffineChannelOp should not be null."); "Output(Out) of AffineChannelOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto scale_dims = ctx->GetInputDim("Scale");
auto b_dims = ctx->GetInputDim("Bias");
const framework::DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout"));
const int64_t C = (data_layout == framework::DataLayout::kNCHW
? x_dims[1]
: x_dims[x_dims.size() - 1]);
PADDLE_ENFORCE_EQ(scale_dims.size(), 1UL);
PADDLE_ENFORCE_EQ(scale_dims[0], C);
PADDLE_ENFORCE_EQ(b_dims.size(), 1UL);
PADDLE_ENFORCE_EQ(b_dims[0], C);
ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", "Out"); ctx->ShareLoD("X", "Out");
} }
...@@ -97,6 +113,27 @@ class AffineChannelOpGrad : public framework::OperatorWithKernel { ...@@ -97,6 +113,27 @@ class AffineChannelOpGrad : public framework::OperatorWithKernel {
} }
}; };
class AffineChannelGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op = new framework::OpDesc();
op->SetType("affine_channel_grad");
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetInput("Scale", Input("Scale"));
op->SetAttrMap(Attrs());
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Scale"), InputGrad("Scale"));
op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
return std::unique_ptr<framework::OpDesc>(op);
}
};
template <typename T> template <typename T>
using EigenArrayMap = using EigenArrayMap =
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>; Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
...@@ -244,8 +281,7 @@ namespace ops = paddle::operators; ...@@ -244,8 +281,7 @@ namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext; using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(affine_channel, ops::AffineChannelOp, REGISTER_OPERATOR(affine_channel, ops::AffineChannelOp,
ops::AffineChannelOpMaker, ops::AffineChannelOpMaker, ops::AffineChannelGradMaker);
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(affine_channel_grad, ops::AffineChannelOpGrad); REGISTER_OPERATOR(affine_channel_grad, ops::AffineChannelOpGrad);
REGISTER_OP_CPU_KERNEL(affine_channel, ops::AffineChannelKernel<CPU, float>, REGISTER_OP_CPU_KERNEL(affine_channel, ops::AffineChannelKernel<CPU, float>,
......
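A small worked instance of the new shape check in AffineChannelOp::InferShape above: for an NCHW input of shape [8, 3, 32, 32] the channel count C is 3, so Scale and Bias must both be 1-D tensors of length 3, while for NHWC the last dimension is used. The helper below is only an illustration of how C is derived, not operator code.

#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

// Channel count that Scale/Bias are expected to match, per the check above.
int64_t ChannelCount(const std::vector<int64_t>& x_dims,
                     const std::string& data_layout) {
  return data_layout == "NCHW" ? x_dims[1] : x_dims.back();
}

int main() {
  const std::vector<int64_t> x_dims = {8, 3, 32, 32};  // example input shape
  std::printf("NCHW: C = %lld, NHWC: C = %lld\n",
              static_cast<long long>(ChannelCount(x_dims, "NCHW")),   // 3
              static_cast<long long>(ChannelCount(x_dims, "NHWC")));  // 32
  return 0;
}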
...@@ -57,7 +57,7 @@ class ConcatOp : public framework::OperatorWithKernel { ...@@ -57,7 +57,7 @@ class ConcatOp : public framework::OperatorWithKernel {
"elements except the specify axis."); "elements except the specify axis.");
} else { } else {
// not check -1 with other in compile time // not check -1 with other in compile time
if (out_dims[j] != -1 && ins[i][j] != -1) { if (out_dims[j] > 0 && ins[i][j] > 0) {
PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j], PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
"Input tensors should have the same " "Input tensors should have the same "
"elements except the specify axis."); "elements except the specify axis.");
......
...@@ -71,7 +71,7 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -71,7 +71,7 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
"1. downgrade_in_infer(default), downgrade the outcome at inference " "1. downgrade_in_infer(default), downgrade the outcome at inference "
"time" "time"
" train: out = input * mask" " train: out = input * mask"
" inference: out = input * dropout_prob" " inference: out = input * (1.0 - dropout_prob)"
"2. upscale_in_train, upscale the outcome at training time, do nothing " "2. upscale_in_train, upscale the outcome at training time, do nothing "
"in inference" "in inference"
" train: out = input * mask / ( 1.0 - dropout_prob )" " train: out = input * mask / ( 1.0 - dropout_prob )"
......
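A worked example of the two dropout_implementation modes documented above, using dropout_prob = 0.5 and an element that happened to be kept (mask = 1); the numbers are only illustrative.

#include <cstdio>

int main() {
  const float dropout_prob = 0.5f;  // assumed probability for this example
  const float input = 4.0f;
  const float mask = 1.0f;          // this element was kept
  // downgrade_in_infer: keep raw values at train time, rescale at inference.
  const float down_train = input * mask;                   // 4
  const float down_infer = input * (1.0f - dropout_prob);  // 2
  // upscale_in_train: rescale kept values at train time, do nothing at inference.
  const float up_train = input * mask / (1.0f - dropout_prob);  // 8
  const float up_infer = input;                                 // 4
  std::printf("%g %g %g %g\n", down_train, down_infer, up_train, up_infer);
  return 0;
}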
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_floordiv_op.h"
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseFloorDivOpMaker : public ElementwiseOpMaker {
protected:
std::string GetName() const override { return "FloorDiv"; }
std::string GetEquation() const override { return "Out = X // Y"; }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(elementwise_floordiv, ops::ElementwiseOp,
ops::ElementwiseFloorDivOpMaker);
REGISTER_OP_CPU_KERNEL(
elementwise_floordiv,
ops::ElementwiseFloorDivKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseFloorDivKernel<paddle::platform::CPUDeviceContext,
int64_t>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_floordiv_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
elementwise_floordiv,
ops::ElementwiseFloorDivKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseFloorDivKernel<plat::CUDADeviceContext, int64_t>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace operators {
template <typename T>
struct FloorDivFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a / b; }
};
template <typename DeviceContext, typename T>
void elementwise_floor_div(const framework::ExecutionContext &ctx,
const framework::Tensor *x,
const framework::Tensor *y, framework::Tensor *z) {
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<FloorDivFunctor<T>, DeviceContext, T>(
ctx, x, y, axis, FloorDivFunctor<T>(), z);
}
template <typename DeviceContext, typename T>
class ElementwiseFloorDivKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<framework::LoDTensor>("X");
auto *y = ctx.Input<framework::LoDTensor>("Y");
auto *z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
// The dtypes of x and y are int64 or int32
elementwise_floor_div<DeviceContext, T>(ctx, x, y, z);
}
};
} // namespace operators
} // namespace paddle
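Since ElementwiseFloorDivKernel is registered only for int and int64_t, FloorDivFunctor's a / b above is plain C++ integer division. A standalone illustration (not part of the operator code):

#include <cstdio>

int main() {
  std::printf("%d\n", 7 / 3);   // 2: equals floor division for non-negative operands
  std::printf("%d\n", -7 / 3);  // -2: C++ truncates toward zero (mathematical floor is -3)
  return 0;
}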
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mod_op.h"
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseModOpMaker : public ElementwiseOpMaker {
protected:
std::string GetName() const override { return "Mod"; }
std::string GetEquation() const override { return "Out = X % Y"; }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(elementwise_mod, ops::ElementwiseOp,
ops::ElementwiseModOpMaker);
REGISTER_OP_CPU_KERNEL(
elementwise_mod,
ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mod_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
elementwise_mod, ops::ElementwiseModKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseModKernel<plat::CUDADeviceContext, int64_t>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace operators {
template <typename T>
struct ModFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a % b; }
};
template <typename DeviceContext, typename T>
void elementwise_mod(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor *z) {
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<ModFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
ModFunctor<T>(), z);
}
template <typename DeviceContext, typename T>
class ElementwiseModKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<framework::LoDTensor>("X");
auto *y = ctx.Input<framework::LoDTensor>("Y");
auto *z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
// The dtypes of x and y are int64 or int32
elementwise_mod<DeviceContext, T>(ctx, x, y, z);
}
};
} // namespace operators
} // namespace paddle
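Likewise, ElementwiseModKernel is registered for int and int64_t, so ModFunctor's a % b above is C++ integer remainder. A standalone illustration:

#include <cstdio>

int main() {
  std::printf("%d\n", 7 % 3);   // 1
  std::printf("%d\n", -7 % 3);  // -1: the result takes the sign of the dividend in C++
  return 0;
}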
...@@ -33,8 +33,51 @@ struct DequantizeFunctor<platform::CPUDeviceContext, T> { ...@@ -33,8 +33,51 @@ struct DequantizeFunctor<platform::CPUDeviceContext, T> {
} }
}; };
template <typename T>
struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& dev_ctx,
const framework::Tensor* in, const framework::Tensor** scales,
const int scale_num, T max_range, framework::Tensor* out) {
if (scale_num == 1) {
const int channel = in->dims()[0];
const T* scale_factor = scales[0]->data<T>();
for (int i = 0; i < channel; i++) {
T s = scale_factor[i];
framework::Tensor one_channel_in = in->Slice(i, i + 1);
framework::Tensor one_channel_out = out->Slice(i, i + 1);
auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
auto& dev = *dev_ctx.eigen_device();
out_e.device(dev) = (s / max_range) * in_e;
}
} else if (scale_num == 2) {
int batch_size = in->dims()[0];
int channel = in->dims()[1];
const T* scale_one = scales[0]->data<T>();
const T* scale_two = scales[1]->data<T>();
for (int i = 0; i < batch_size; i++) {
framework::Tensor one_batch_in = in->Slice(i, i + 1).Resize(
framework::slice_ddim(in->dims(), 1, in->dims().size()));
framework::Tensor one_batch_out = out->Slice(i, i + 1).Resize(
framework::slice_ddim(out->dims(), 1, out->dims().size()));
for (int j = 0; j < channel; j++) {
T s = scale_one[j];
framework::Tensor one_channel_in = one_batch_in.Slice(j, j + 1);
framework::Tensor one_channel_out = one_batch_out.Slice(j, j + 1);
auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
auto& dev = *dev_ctx.eigen_device();
out_e.device(dev) = (s * scale_two[0] / max_range) * in_e;
}
}
}
}
};
template struct DequantizeFunctor<platform::CPUDeviceContext, float>; template struct DequantizeFunctor<platform::CPUDeviceContext, float>;
template struct DequantizeFunctor<platform::CPUDeviceContext, double>; template struct DequantizeFunctor<platform::CPUDeviceContext, double>;
template struct ChannelDequantizeFunctor<platform::CPUDeviceContext, float>;
template struct ChannelDequantizeFunctor<platform::CPUDeviceContext, double>;
class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel { class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
public: public:
......
...@@ -44,8 +44,66 @@ struct DequantizeFunctor<platform::CUDADeviceContext, T> { ...@@ -44,8 +44,66 @@ struct DequantizeFunctor<platform::CUDADeviceContext, T> {
} }
}; };
template <typename T>
__global__ void DequantizeOneScale(const T* in, const T* scale, T max_range,
int num, int channel, T* out) {
int tid = threadIdx.x;
int channel_size = num / channel;
const T* in_c = in + blockIdx.x * channel_size;
T* out_c = out + blockIdx.x * channel_size;
for (int i = tid; i < channel_size; i += blockDim.x) {
out_c[i] = in_c[i] * scale[blockIdx.x] / max_range;
}
}
template <typename T>
__global__ void DequantizeTwoScale(const T* in, const T* scale_one,
const T* scale_two, T max_range, int num,
int batch_size, int channel, T* out) {
int tid = threadIdx.x;
int channel_size = num / (batch_size * channel);
int scale_index = blockIdx.x % channel;
const T* in_c = in + blockIdx.x * channel_size;
T* out_c = out + blockIdx.x * channel_size;
for (int i = tid; i < channel_size; i += blockDim.x) {
out_c[i] = in_c[i] * scale_one[scale_index] * scale_two[0] / max_range;
}
}
template <typename T>
struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& dev_ctx,
const framework::Tensor* in, const framework::Tensor** scales,
const int scale_num, T max_range, framework::Tensor* out) {
const T* in_data = in->data<T>();
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
if (scale_num == 1) {
int num = in->numel();
int channel = in->dims()[0];
const T* scale_factor = scales[0]->data<T>();
int block = 1024;
int grid = channel;
DequantizeOneScale<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, scale_factor, max_range, num, channel, out_data);
} else if (scale_num == 2) {
int num = in->numel();
int batch_size = in->dims()[0];
int channel = in->dims()[1];
const T* scale_one = scales[0]->data<T>();
const T* scale_two = scales[1]->data<T>();
int block = 1024;
int grid = batch_size * channel;
DequantizeTwoScale<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, scale_one, scale_two, max_range, num, batch_size, channel,
out_data);
}
}
};
template struct DequantizeFunctor<platform::CUDADeviceContext, float>; template struct DequantizeFunctor<platform::CUDADeviceContext, float>;
template struct DequantizeFunctor<platform::CUDADeviceContext, double>; template struct DequantizeFunctor<platform::CUDADeviceContext, double>;
template struct ChannelDequantizeFunctor<platform::CUDADeviceContext, float>;
template struct ChannelDequantizeFunctor<platform::CUDADeviceContext, double>;
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <vector> #include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -28,6 +29,13 @@ struct DequantizeFunctor { ...@@ -28,6 +29,13 @@ struct DequantizeFunctor {
framework::Tensor* out); framework::Tensor* out);
}; };
template <typename DeviceContext, typename T>
struct ChannelDequantizeFunctor {
void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in,
const framework::Tensor** scales, const int scale_num,
T max_range, framework::Tensor* out);
};
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class FakeDequantizeMaxAbsKernel : public framework::OpKernel<T> { class FakeDequantizeMaxAbsKernel : public framework::OpKernel<T> {
public: public:
...@@ -54,32 +62,33 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> { ...@@ -54,32 +62,33 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
auto scales = ctx.MultiInput<framework::Tensor>("Scales"); auto scales = ctx.MultiInput<framework::Tensor>("Scales");
auto* out = ctx.Output<framework::Tensor>("Out"); auto* out = ctx.Output<framework::Tensor>("Out");
PADDLE_ENFORCE_EQ(scales[0]->numel(), in->dims()[0],
"The number of first scale values must be the same with "
"first dimension value of Input(X).");
auto quant_bits = ctx.Attr<std::vector<int>>("quant_bits"); auto quant_bits = ctx.Attr<std::vector<int>>("quant_bits");
int max_range = std::pow(2, quant_bits[0] - 1) - 1; int max_range = 1;
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
out->mutable_data<T>(dev_ctx.GetPlace()); out->mutable_data<T>(dev_ctx.GetPlace());
int scale_num = scales.size();
auto dequant = DequantizeFunctor<DeviceContext, T>(); if (scale_num == 1) {
for (int64_t i = 0; i < in->dims()[0]; i++) { PADDLE_ENFORCE_EQ(
framework::Tensor one_channel_in = in->Slice(i, i + 1); scales[0]->numel(), in->dims()[0],
framework::Tensor one_channel_out = out->Slice(i, i + 1); "The number of first scale values must be the same with "
framework::Tensor one_channel_scale = scales[0]->Slice(i, i + 1); "first dimension value of Input(X) when the `Scales` has only one "
dequant(dev_ctx, &one_channel_in, &one_channel_scale, "element.");
static_cast<T>(max_range), &one_channel_out); max_range *= (std::pow(2, quant_bits[0] - 1) - 1);
} } else if (scale_num == 2) {
PADDLE_ENFORCE_EQ(
if (scales.size() == 2) { scales[0]->numel(), in->dims()[1],
"The number of first scale values must be the same with "
"second dimension value of Input(X) when the `Scales` has two "
"elements.");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
scales[1]->numel(), 1, scales[1]->numel(), 1,
"The second scale tensor should only have one value at now."); "The second scale tensor should only have one value at now.");
max_range = std::pow(2, quant_bits[1] - 1) - 1; max_range *= (std::pow(2, quant_bits[0] - 1) - 1) *
dequant(dev_ctx, out, scales[1], static_cast<T>(max_range), out); (std::pow(2, quant_bits[1] - 1) - 1);
} }
ChannelDequantizeFunctor<DeviceContext, T>()(
dev_ctx, in, scales.data(), scale_num, static_cast<T>(max_range), out);
} }
}; };
......
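A worked example of the max_range computation in FakeChannelWiseDequantizeMaxAbsKernel above, assuming quant_bits of 8 at both levels (values chosen only for illustration); an element is then restored as x * s / max_range with one scale, or x * s_channel * s_whole / max_range with two scales.

#include <cmath>
#include <cstdio>

int main() {
  const int quant_bits0 = 8, quant_bits1 = 8;  // assumed bit widths
  const int one_scale_range =
      static_cast<int>(std::pow(2, quant_bits0 - 1)) - 1;    // 127
  const int two_scale_range =
      one_scale_range *
      (static_cast<int>(std::pow(2, quant_bits1 - 1)) - 1);  // 127 * 127 = 16129
  std::printf("one scale: %d, two scales: %d\n", one_scale_range,
              two_scale_range);
  return 0;
}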
...@@ -37,6 +37,21 @@ struct FindAbsMaxFunctor<platform::CPUDeviceContext, T> { ...@@ -37,6 +37,21 @@ struct FindAbsMaxFunctor<platform::CPUDeviceContext, T> {
template struct FindAbsMaxFunctor<platform::CPUDeviceContext, float>; template struct FindAbsMaxFunctor<platform::CPUDeviceContext, float>;
template <typename T>
struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx, const T* in,
const int num, const int channel, T* out) {
const int channel_size = num / channel;
for (int i = 0; i < channel; i++) {
auto* start = in + i * channel_size;
auto* end = in + (i + 1) * channel_size;
out[i] = std::abs(*(std::max_element(start, end, Compare<T>())));
}
}
};
template struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, float>;
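// A self-contained sketch of the per-channel abs-max rule implemented by
// FindChannelAbsMaxFunctor above, assuming the channel is the outermost
// dimension so each channel occupies num / channel contiguous values
// (standalone illustration, not operator code):
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

void ChannelAbsMax(const float* in, int num, int channel, float* out) {
  const int channel_size = num / channel;
  for (int c = 0; c < channel; ++c) {
    float m = 0.0f;
    for (int i = 0; i < channel_size; ++i) {
      m = std::max(m, std::fabs(in[c * channel_size + i]));
    }
    out[c] = m;
  }
}

int main() {
  std::vector<float> in = {0.5f, -2.0f, 1.0f, 3.0f, -0.25f, 0.75f};  // 2 channels
  std::vector<float> out(2);
  ChannelAbsMax(in.data(), static_cast<int>(in.size()), 2, out.data());
  std::printf("%f %f\n", out[0], out[1]);  // 2.000000 3.000000
  return 0;
}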
template <typename T> template <typename T>
struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> { struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx, void operator()(const platform::CPUDeviceContext& ctx,
...@@ -53,6 +68,36 @@ struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> { ...@@ -53,6 +68,36 @@ struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
template struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, float>; template struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, float>;
template <typename T>
struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& scale,
const int bin_cnt, const int channel,
framework::Tensor* out) {
auto* scale_data = scale.data<T>();
auto* in_data = in.data<T>();
auto* out_data = out->mutable_data<T>(ctx.GetPlace());
const int channel_size = in.numel() / channel;
platform::Transform<platform::CPUDeviceContext> trans;
for (int i = 0; i < channel; i++) {
T s = scale_data[i];
auto* start = in_data + i * channel_size;
auto* end = in_data + (i + 1) * channel_size;
trans(ctx, start, end, out_data + i * channel_size,
ClipFunctor<T>(-s, s));
}
for (int i = 0; i < channel; i++) {
T s = scale_data[i];
framework::Tensor one_channel_out = out->Slice(i, i + 1);
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
out_e.device(*ctx.eigen_device()) = (bin_cnt / s * out_e).round();
}
}
};
template struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext,
float>;
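// Worked example of the per-channel clip + quantize rule implemented above:
// clip x into [-s, s], then emit round(bin_cnt / s * x). The scale s = 2.0
// and bin_cnt = 127 (8 bits) below are assumed values for illustration only.
#include <algorithm>
#include <cmath>
#include <cstdio>

float FakeQuant(float x, float s, int bin_cnt) {
  const float v = std::min(std::max(x, -s), s);
  return std::round(bin_cnt / s * v);
}

int main() {
  std::printf("%f\n", FakeQuant(1.0f, 2.0f, 127));   // round(63.5)  -> 64
  std::printf("%f\n", FakeQuant(-3.0f, 2.0f, 127));  // clipped to -2 -> -127
  return 0;
}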
template <typename T> template <typename T>
struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> { struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx, void operator()(const platform::CPUDeviceContext& ctx,
...@@ -169,10 +214,10 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel { ...@@ -169,10 +214,10 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
ctx->HasOutput("Out"), ctx->HasOutput("Out"),
"Output(Out) of FakeChannelWiseQuantizeOp should not be null."); "Output(Out) of FakeChannelWiseQuantizeOp should not be null.");
PADDLE_ENFORCE( PADDLE_ENFORCE(
ctx->HasOutput("OutScales"), ctx->HasOutput("OutScale"),
"Output(Scales) of FakeChannelWiseQuantizeOp should not be null."); "Output(Scale) of FakeChannelWiseQuantizeOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->SetOutputDim("OutScales", {ctx->GetInputDim("X")[0]}); ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[0]});
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
...@@ -192,7 +237,7 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker ...@@ -192,7 +237,7 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
AddOutput("Out", AddOutput("Out",
"(Tensor) Output of quantized low level tensor, " "(Tensor) Output of quantized low level tensor, "
"but also saved as float data type."); "but also saved as float data type.");
AddOutput("OutScales", "(Tensor) Current channel wise scale"); AddOutput("OutScale", "(Tensor) Current channel wise scale");
AddAttr<int>("bit_length", "(int, default 8)") AddAttr<int>("bit_length", "(int, default 8)")
.SetDefault(8) .SetDefault(8)
.AddCustomChecker([](const int& bit_length) { .AddCustomChecker([](const int& bit_length) {
......
...@@ -74,6 +74,45 @@ struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> { ...@@ -74,6 +74,45 @@ struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> {
template struct FindAbsMaxFunctor<platform::CUDADeviceContext, float>; template struct FindAbsMaxFunctor<platform::CUDADeviceContext, float>;
template <typename T>
__global__ void FindChannelAbsMaxKernel(const T* in, const int n, const int c,
T* out) {
int tid = threadIdx.x;
int channel_size = n / c;
const T* in_c = in + blockIdx.x * channel_size;
extern __shared__ T shared_max_data[];
shared_max_data[tid] = T(0);
for (int i = tid; i < channel_size; i += blockDim.x) {
T tmp = fabs(in_c[i]);
if (tmp > shared_max_data[tid]) {
shared_max_data[tid] = tmp;
}
}
__syncthreads();
for (int i = blockDim.x / 2; i > 0; i >>= 1) {
if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) {
shared_max_data[tid] = shared_max_data[tid + i];
}
__syncthreads();
}
if (tid == 0) {
out[blockIdx.x] = shared_max_data[0];
}
}
template <typename T>
struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx, const T* in,
const int num, const int channel, T* out) {
int block = 1024;
int grid = channel;
FindChannelAbsMaxKernel<T><<<grid, block, 1024 * sizeof(T), ctx.stream()>>>(
in, num, channel, out);
}
};
template struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, float>;
template <typename T> template <typename T>
__global__ void ClipAndQuantKernel(const T* in, const T* scale, __global__ void ClipAndQuantKernel(const T* in, const T* scale,
const int bin_cnt, const int n, T* out) { const int bin_cnt, const int n, T* out) {
...@@ -82,14 +121,76 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale, ...@@ -82,14 +121,76 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale,
T s = scale[0]; T s = scale[0];
for (int i = bid; i < n; i += blockDim.x * gridDim.x) { for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
T x = in[bid]; T x = in[i];
T v = x > s ? s : x; T v = x > s ? s : x;
v = v < -s ? -s : v; v = v < -s ? -s : v;
v = bin_cnt / s * v; v = bin_cnt / s * v;
out[bid] = round(v); out[i] = round(v);
} }
} }
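The only functional change in ClipAndQuantKernel above is indexing with i instead of bid inside the grid-stride loop, so each iteration reads and writes element i rather than repeatedly touching the thread's starting offset. The following host-side sketch (plain C++, with the CUDA thread and block indices replaced by ordinary loop variables; names are illustrative) shows why the strided pattern covers every element exactly once.

// Grid-stride traversal sketch: each "thread" starts at its global id `bid`
// and then advances by the total number of threads, touching element i
// (not bid) on every step, mirroring the in[i]/out[i] fix above.
#include <cstdio>
#include <vector>

void StridedVisit(int num_threads, int n, std::vector<int>* visits) {
  visits->assign(n, 0);
  for (int bid = 0; bid < num_threads; ++bid) {
    for (int i = bid; i < n; i += num_threads) {
      (*visits)[i] += 1;  // with i, every element is visited exactly once
    }
  }
}

int main() {
  std::vector<int> visits;
  StridedVisit(/*num_threads=*/4, /*n=*/10, &visits);
  for (int v : visits) std::printf("%d ", v);  // prints ten 1s
  std::printf("\n");
  return 0;
}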
template <typename T>
struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& scale,
const int bin_cnt, framework::Tensor* out) {
int num = in.numel();
int block = 1024;
int grid = (block - 1 + num) / block;
const T* in_data = in.data<T>();
const T* scale_data = scale.data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace());
ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
in_data, scale_data, bin_cnt, num, out_data);
}
};
template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;
template <typename T>
__global__ void ChannelClipAndQuantKernel(const T* in, const T* scale,
const int bin_cnt, const int n,
const int c, T* out) {
int tid = threadIdx.x;
int channel_size = n / c;
const T* in_c = in + blockIdx.x * channel_size;
T* out_c = out + blockIdx.x * channel_size;
T s = scale[blockIdx.x];
for (int i = tid; i < channel_size; i += blockDim.x) {
T x = in_c[i];
T v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt / s * v;
out_c[i] = round(v);
}
}
template <typename T>
struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& scale,
const int bin_cnt, const int channel,
framework::Tensor* out) {
int num = in.numel();
int block = 1024;
int grid = channel;
const T* in_data = in.data<T>();
const T* scale_data = scale.data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace());
ChannelClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
in_data, scale_data, bin_cnt, num, channel, out_data);
}
};
template struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext,
float>;
template <typename T> template <typename T>
__global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale, __global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale,
const T* last_scale, const T* last_scale,
...@@ -182,26 +283,6 @@ struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> { ...@@ -182,26 +283,6 @@ struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> {
template struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, template struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext,
float>; float>;
template <typename T>
struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& scale,
const int bin_cnt, framework::Tensor* out) {
int num = in.numel();
int block = 1024;
int grid = (block - 1 + num) / block;
const T* in_data = in.data<T>();
const T* scale_data = scale.data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace());
ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
in_data, scale_data, bin_cnt, num, out_data);
}
};
template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
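FindChannelAbsMaxKernel above launches one block per channel: each of the 1024 threads folds a strided slice of the channel into shared memory, and the halving loop then combines pairs of slots until slot 0 holds the channel's absolute maximum. A sequential sketch of that halving step, assuming the slot count is a power of two as the 1024-thread launch guarantees (plain C++, not CUDA):

// Sequential sketch of the power-of-two "halving" max-reduction used by
// FindChannelAbsMaxKernel: slot 0 ends up holding the maximum of all slots.
#include <algorithm>
#include <cstdio>
#include <vector>

float HalvingMax(std::vector<float> slots) {  // slots.size() must be 2^k
  for (size_t step = slots.size() / 2; step > 0; step >>= 1) {
    for (size_t tid = 0; tid < step; ++tid) {
      slots[tid] = std::max(slots[tid], slots[tid + step]);
    }
    // On the GPU this is where __syncthreads() keeps the lanes in step.
  }
  return slots[0];
}

int main() {
  std::printf("%g\n", HalvingMax({0.5f, 3.f, 2.f, 7.f, 1.f, 0.f, 4.f, 6.f}));
  return 0;  // prints 7
}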
...@@ -42,6 +42,19 @@ struct FindRangeAbsMaxFunctor { ...@@ -42,6 +42,19 @@ struct FindRangeAbsMaxFunctor {
framework::Tensor* scales_arr, framework::Tensor* out_scale); framework::Tensor* scales_arr, framework::Tensor* out_scale);
}; };
template <typename DeviceContext, typename T>
struct FindChannelAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const T* in, const int num,
const int channel, T* out);
};
template <typename DeviceContext, typename T>
struct ChannelClipAndFakeQuantFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in,
const framework::Tensor& scale, const int bin_cnt,
const int channel, framework::Tensor* out);
};
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct FindMovingAverageAbsMaxFunctor { struct FindMovingAverageAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in_accum, void operator()(const DeviceContext& ctx, const framework::Tensor& in_accum,
...@@ -78,29 +91,18 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> { ...@@ -78,29 +91,18 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
auto* in = context.Input<framework::Tensor>("X"); auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out"); auto* out = context.Output<framework::Tensor>("Out");
auto* out_scales = context.Output<framework::Tensor>("OutScales"); auto* out_scale = context.Output<framework::Tensor>("OutScale");
T* out_scales_data = out_scales->mutable_data<T>(context.GetPlace()); T* out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
int bit_length = context.Attr<int>("bit_length"); int bit_length = context.Attr<int>("bit_length");
int bin_cnt = std::pow(2, bit_length - 1) - 1; int bin_cnt = std::pow(2, bit_length - 1) - 1;
auto& dev_ctx = context.template device_context<DeviceContext>(); auto& dev_ctx = context.template device_context<DeviceContext>();
auto find_abs_max = FindAbsMaxFunctor<DeviceContext, T>(); FindChannelAbsMaxFunctor<DeviceContext, T>()(
for (int64_t i = 0; i < in->dims()[0]; i++) { dev_ctx, in->data<T>(), in->numel(), in->dims()[0], out_scale_data);
framework::Tensor one_channel = in->Slice(i, i + 1); ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
const T* one_channel_data = one_channel.data<T>(); dev_ctx, *in, *out_scale, bin_cnt, in->dims()[0], out);
find_abs_max(dev_ctx, one_channel_data, one_channel.numel(),
&out_scales_data[i]);
}
auto clip_quant = ClipAndFakeQuantFunctor<DeviceContext, T>();
for (int64_t i = 0; i < in->dims()[0]; i++) {
framework::Tensor one_channel_in = in->Slice(i, i + 1);
framework::Tensor one_channel_out = out->Slice(i, i + 1);
framework::Tensor one_channel_scale = out_scales->Slice(i, i + 1);
clip_quant(dev_ctx, one_channel_in, one_channel_scale, bin_cnt,
&one_channel_out);
}
} }
}; };
......
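As a quick sanity check on the bin_cnt used by the kernels above: with the default bit_length of 8, bin_cnt = 2^(8-1) - 1 = 127, so quantized values land on the integer grid [-127, 127]. A tiny standalone computation of the same value (illustration only, not Paddle code):

#include <cmath>
#include <cstdio>

int main() {
  int bit_length = 8;  // default attribute value of the fake-quantize ops
  int bin_cnt = static_cast<int>(std::pow(2, bit_length - 1)) - 1;  // 127
  std::printf("bin_cnt = %d\n", bin_cnt);  // values map to [-127, 127]
  return 0;
}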
...@@ -11,89 +11,27 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,89 +11,27 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <fstream>
#include "paddle/fluid/framework/data_type_transform.h" #include <string>
#include "paddle/fluid/framework/op_registry.h" #include <vector>
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/operators/load_combine_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class LoadCombineOp : public framework::OperatorBase { class LoadCombineOp : public framework::OperatorWithKernel {
public: public:
LoadCombineOp(const std::string &type, using framework::OperatorWithKernel::OperatorWithKernel;
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs, void InferShape(framework::InferShapeContext *ctx) const override {}
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} protected:
framework::OpKernelType GetExpectedKernelType(
private: const framework::ExecutionContext &ctx) const override {
void RunImpl(const framework::Scope &scope, framework::OpKernelType kt = framework::OpKernelType(
const platform::Place &place) const override { framework::proto::VarType::FP32, ctx.GetPlace());
auto filename = Attr<std::string>("file_path"); return kt;
auto load_as_fp16 = Attr<bool>("load_as_fp16");
auto model_from_memory = Attr<bool>("model_from_memory");
auto out_var_names = Outputs("Out");
PADDLE_ENFORCE_GT(
static_cast<int>(out_var_names.size()), 0,
"The number of output variables should be greater than 0.");
if (!model_from_memory) {
std::ifstream fin(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fin),
"Cannot open file %s for load_combine op", filename);
LoadParamsFromBuffer(scope, place, &fin, load_as_fp16, out_var_names);
} else {
PADDLE_ENFORCE(!filename.empty(), "Cannot load file from memory");
std::stringstream fin(filename, std::ios::in | std::ios::binary);
LoadParamsFromBuffer(scope, place, &fin, load_as_fp16, out_var_names);
}
}
void LoadParamsFromBuffer(
const framework::Scope &scope, const platform::Place &place,
std::istream *buffer, bool load_as_fp16,
const std::vector<std::string> &out_var_names) const {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
for (size_t i = 0; i < out_var_names.size(); i++) {
auto *out_var = scope.FindVar(out_var_names[i]);
PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found",
out_var_names[i]);
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
// Error checking
PADDLE_ENFORCE(static_cast<bool>(*buffer), "Cannot read more");
// Get data from fin to tensor
DeserializeFromStream(*buffer, tensor, dev_ctx);
auto in_dtype = tensor->type();
auto out_dtype =
load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) {
// convert to float16 tensor
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
framework::LoDTensor fp16_tensor;
// copy LoD info to the new tensor
fp16_tensor.set_lod(tensor->lod());
framework::TransDataType(in_kernel_type, out_kernel_type, *tensor,
&fp16_tensor);
// reset output tensor
out_var->Clear();
tensor = out_var->GetMutable<framework::LoDTensor>();
tensor->set_lod(fp16_tensor.lod());
tensor->ShareDataWith(fp16_tensor);
}
}
buffer->peek();
PADDLE_ENFORCE(buffer->eof(),
"You are not allowed to load partial data via "
"load_combine_op, use load_op instead.");
} }
}; };
...@@ -124,21 +62,30 @@ class LoadCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -124,21 +62,30 @@ class LoadCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC( AddComment(R"DOC(
LoadCombine Operator. LoadCombine Operator.
LoadCombine operator loads LoDTensor variables from a file, which could be LoadCombine operator loads LoDTensor variables from a file, which could be
loaded in memory already. The file should contain one or more LoDTensors loaded in memory already. The file should contain one or more LoDTensors
serialized using the SaveCombine operator. The serialized using the SaveCombine operator. The
LoadCombine operator applies a deserialization strategy to appropriately load LoadCombine operator applies a deserialization strategy to appropriately load
the LodTensors, and this strategy complements the serialization strategy used the LodTensors, and this strategy complements the serialization strategy used
in the SaveCombine operator. Hence, the LoadCombine operator is tightly coupled in the SaveCombine operator. Hence, the LoadCombine operator is tightly coupled
with the SaveCombine operator, and can only deserialize one or more LoDTensors with the SaveCombine operator, and can only deserialize one or more LoDTensors
that were saved using the SaveCombine operator. that were saved using the SaveCombine operator.
)DOC"); )DOC");
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(load_combine, ops::LoadCombineOp, REGISTER_OPERATOR(load_combine, ops::LoadCombineOp,
ops::LoadCombineOpProtoMaker); ops::LoadCombineOpProtoMaker);
REGISTER_OP_CPU_KERNEL(
load_combine,
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/load_combine_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
load_combine,
ops::LoadCombineOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::LoadCombineOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::LoadCombineOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::LoadCombineOpKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::LoadCombineOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <fstream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class LoadCombineOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace();
auto filename = ctx.Attr<std::string>("file_path");
auto load_as_fp16 = ctx.Attr<bool>("load_as_fp16");
auto model_from_memory = ctx.Attr<bool>("model_from_memory");
auto &out_var_names = ctx.Outputs("Out");
PADDLE_ENFORCE_GT(
static_cast<int>(out_var_names.size()), 0,
"The number of output variables should be greater than 0.");
if (!model_from_memory) {
std::ifstream fin(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fin),
"Cannot open file %s for load_combine op", filename);
LoadParamsFromBuffer(ctx, place, &fin, load_as_fp16, out_var_names);
} else {
PADDLE_ENFORCE(!filename.empty(), "Cannot load file from memory");
std::stringstream fin(filename, std::ios::in | std::ios::binary);
LoadParamsFromBuffer(ctx, place, &fin, load_as_fp16, out_var_names);
}
}
void LoadParamsFromBuffer(
const framework::ExecutionContext &context, const platform::Place &place,
std::istream *buffer, bool load_as_fp16,
const std::vector<std::string> &out_var_names) const {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
auto out_vars = context.MultiOutputVar("Out");
for (size_t i = 0; i < out_var_names.size(); i++) {
PADDLE_ENFORCE(out_vars[i] != nullptr,
"Output variable %s cannot be found", out_var_names[i]);
auto *tensor = out_vars[i]->GetMutable<framework::LoDTensor>();
// Error checking
PADDLE_ENFORCE(static_cast<bool>(*buffer), "Cannot read more");
// Get data from fin to tensor
DeserializeFromStream(*buffer, tensor, dev_ctx);
auto in_dtype = tensor->type();
auto out_dtype =
load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) {
// convert to float16 tensor
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
framework::LoDTensor fp16_tensor;
// copy LoD info to the new tensor
fp16_tensor.set_lod(tensor->lod());
framework::TransDataType(in_kernel_type, out_kernel_type, *tensor,
&fp16_tensor);
// reset output tensor
out_vars[i]->Clear();
tensor = out_vars[i]->GetMutable<framework::LoDTensor>();
tensor->set_lod(fp16_tensor.lod());
tensor->ShareDataWith(fp16_tensor);
}
}
buffer->peek();
PADDLE_ENFORCE(buffer->eof(),
"You are not allowed to load partial data via "
"load_combine_op, use load_op instead.");
}
};
} // namespace operators
} // namespace paddle
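LoadParamsFromBuffer above is written against std::istream precisely so the two branches of model_from_memory can share it: a std::ifstream over file_path, or a std::stringstream over a string that already holds the serialized bytes. The sketch below illustrates that branching with the standard library only; ConsumeStream is a placeholder standing in for the deserialization the real kernel performs, and all names here are invented for illustration.

#include <fstream>
#include <iostream>
#include <sstream>
#include <string>

// Placeholder for DeserializeFromStream and friends: just counts bytes read.
std::size_t ConsumeStream(std::istream* buffer) {
  std::size_t n = 0;
  char c;
  while (buffer->get(c)) ++n;
  return n;
}

std::size_t LoadParams(const std::string& path_or_bytes,
                       bool model_from_memory) {
  if (!model_from_memory) {
    std::ifstream fin(path_or_bytes, std::ios::binary);  // treat it as a path
    if (!fin) return 0;
    return ConsumeStream(&fin);
  }
  // Treat the string itself as the serialized payload, as load_combine does.
  std::stringstream fin(path_or_bytes, std::ios::in | std::ios::binary);
  return ConsumeStream(&fin);
}

int main() {
  std::cout << LoadParams("raw-bytes-in-memory", /*model_from_memory=*/true)
            << " bytes consumed\n";
  return 0;
}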
...@@ -11,89 +11,26 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,89 +11,26 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <fstream>
#include "paddle/fluid/framework/data_type_transform.h" #include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/operators/load_op.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class LoadOp : public framework::OperatorBase { class LoadOp : public framework::OperatorWithKernel {
public: public:
LoadOp(const std::string &type, const framework::VariableNameMap &inputs, using framework::OperatorWithKernel::OperatorWithKernel;
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
auto filename = Attr<std::string>("file_path");
std::ifstream fin(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
filename);
auto out_var_name = Output("Out"); void InferShape(framework::InferShapeContext *ctx) const override {}
auto *out_var = scope.FindVar(out_var_name);
PADDLE_ENFORCE(out_var != nullptr,
"Output variable %s cannot be found in scope %p",
out_var_name, &scope);
if (out_var->IsType<framework::LoDTensor>()) { protected:
LoadLodTensor(fin, place, out_var); framework::OpKernelType GetExpectedKernelType(
} else if (out_var->IsType<framework::SelectedRows>()) { const framework::ExecutionContext &ctx) const override {
LoadSelectedRows(fin, place, out_var); framework::OpKernelType kt = framework::OpKernelType(
} else { framework::proto::VarType::FP32, platform::CPUPlace());
PADDLE_ENFORCE( return kt;
false,
"Load only support LoDTensor and SelectedRows, %s has wrong type",
out_var_name);
}
}
void LoadLodTensor(std::istream &fin, const platform::Place &place,
framework::Variable *var) const {
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
auto *tensor = var->GetMutable<framework::LoDTensor>();
DeserializeFromStream(fin, tensor, dev_ctx);
auto load_as_fp16 = Attr<bool>("load_as_fp16");
auto in_dtype = tensor->type();
auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) {
// convert to float16 tensor
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
framework::LoDTensor fp16_tensor;
// copy LoD info to the new tensor
fp16_tensor.set_lod(tensor->lod());
framework::TransDataType(in_kernel_type, out_kernel_type, *tensor,
&fp16_tensor);
// reset output tensor
var->Clear();
tensor = var->GetMutable<framework::LoDTensor>();
tensor->set_lod(fp16_tensor.lod());
tensor->ShareDataWith(fp16_tensor);
}
}
void LoadSelectedRows(std::istream &fin, const platform::Place &place,
framework::Variable *var) const {
auto *selectedRows = var->GetMutable<framework::SelectedRows>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::DeserializeFromStream(fin, selectedRows, dev_ctx);
selectedRows->SyncIndex();
} }
}; };
...@@ -116,8 +53,15 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -116,8 +53,15 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"file."); "file.");
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(load, ops::LoadOp, ops::LoadOpProtoMaker); REGISTER_OPERATOR(load, ops::LoadOp, ops::LoadOpProtoMaker);
REGISTER_OP_CPU_KERNEL(
load, ops::LoadOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::LoadOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::LoadOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::LoadOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/load_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
load, ops::LoadOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::LoadOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::LoadOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::LoadOpKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::LoadOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <fstream>
#include <string>
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class LoadOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace();
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
auto filename = ctx.Attr<std::string>("file_path");
std::ifstream fin(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
filename);
auto out_var_name = ctx.Outputs("Out").data();
auto *out_var = ctx.OutputVar("Out");
    PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found",
                   out_var_name);
if (out_var->IsType<framework::LoDTensor>()) {
LoadLodTensor(fin, place, out_var, ctx);
} else if (out_var->IsType<framework::SelectedRows>()) {
LoadSelectedRows(fin, place, out_var);
} else {
PADDLE_ENFORCE(
false,
"Load only support LoDTensor and SelectedRows, %s has wrong type",
out_var_name);
}
}
void LoadLodTensor(std::istream &fin, const platform::Place &place,
framework::Variable *var,
const framework::ExecutionContext &ctx) const {
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
auto *tensor = var->GetMutable<framework::LoDTensor>();
DeserializeFromStream(fin, tensor, dev_ctx);
auto load_as_fp16 = ctx.Attr<bool>("load_as_fp16");
auto in_dtype = tensor->type();
auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) {
// convert to float16 tensor
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
framework::LoDTensor fp16_tensor;
// copy LoD info to the new tensor
fp16_tensor.set_lod(tensor->lod());
framework::TransDataType(in_kernel_type, out_kernel_type, *tensor,
&fp16_tensor);
// reset output tensor
var->Clear();
tensor = var->GetMutable<framework::LoDTensor>();
tensor->set_lod(fp16_tensor.lod());
tensor->ShareDataWith(fp16_tensor);
}
}
void LoadSelectedRows(std::istream &fin, const platform::Place &place,
framework::Variable *var) const {
auto *selectedRows = var->GetMutable<framework::SelectedRows>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::DeserializeFromStream(fin, selectedRows, dev_ctx);
selectedRows->SyncIndex();
}
};
} // namespace operators
} // namespace paddle
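The load_as_fp16 handling in LoadLodTensor above (and in LoadParamsFromBuffer of load_combine_op.h) only changes the target data type: the tensor is deserialized in its stored type and transformed afterwards when the stored and requested types differ. A minimal sketch of that decision, using a stand-in enum rather than framework::proto::VarType (names invented for illustration):

#include <cstdio>

enum class DType { FP32, FP16, INT64 };

// Mirrors the out_dtype selection above: convert only when the caller asked
// for fp16 and the stored type is something else.
DType TargetDType(DType stored, bool load_as_fp16) {
  return load_as_fp16 ? DType::FP16 : stored;
}

bool NeedsTransform(DType stored, bool load_as_fp16) {
  return TargetDType(stored, load_as_fp16) != stored;
}

int main() {
  std::printf("%d\n", NeedsTransform(DType::FP32, true));   // 1: cast to fp16
  std::printf("%d\n", NeedsTransform(DType::FP16, true));   // 0: already fp16
  std::printf("%d\n", NeedsTransform(DType::FP32, false));  // 0: keep as stored
  return 0;
}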
...@@ -32,7 +32,10 @@ class LoDResetOp : public framework::OperatorWithKernel { ...@@ -32,7 +32,10 @@ class LoDResetOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_GT(level0.size(), 1, PADDLE_ENFORCE_GT(level0.size(), 1,
"If Input(Y) not provided, the target lod should be " "If Input(Y) not provided, the target lod should be "
"specified by attribute `target_lod`."); "specified by attribute `target_lod`.");
} else {
ctx->ShareLoD("Y", "Out");
} }
ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
} }
......
...@@ -78,12 +78,6 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> { ...@@ -78,12 +78,6 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
"The numel of 'pad_value' can only be 1 or be equal to the " "The numel of 'pad_value' can only be 1 or be equal to the "
"'step_width'."); "'step_width'.");
if (!norm_by_times && seq_num == 1UL && pad_seq_len == max_seq_len) {
TensorCopy(seq_tensor, context.GetPlace(), context, pad_tensor);
pad_tensor->Resize(pad_tensor_dims);
return;
}
const int kBlockSize = 512; const int kBlockSize = 512;
/* At least use 32 threads to copy sequence_width elements, /* At least use 32 threads to copy sequence_width elements,
...@@ -129,12 +123,13 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> { ...@@ -129,12 +123,13 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len, CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
step_width, layout); step_width, layout);
/*
if (!norm_by_times && seq_num == 1UL && pad_seq_len == max_seq_len) { if (!norm_by_times && seq_num == 1UL && pad_seq_len == max_seq_len) {
TensorCopy(pad_tensor, context.GetPlace(), context, seq_tensor); TensorCopy(pad_tensor, context.GetPlace(), context, seq_tensor);
seq_tensor->Resize(seq_tensor_dims); seq_tensor->Resize(seq_tensor_dims);
return; return;
} }
*/
const int kBlockSize = 512; const int kBlockSize = 512;
......
...@@ -290,8 +290,10 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -290,8 +290,10 @@ class MatMulOp : public framework::OperatorWithKernel {
context->Attrs().Get<bool>("transpose_Y")); context->Attrs().Get<bool>("transpose_Y"));
PADDLE_ENFORCE_EQ(mat_dim_x.width_, mat_dim_y.height_); PADDLE_ENFORCE_EQ(mat_dim_x.width_, mat_dim_y.height_);
PADDLE_ENFORCE(mat_dim_x.batch_size_ == mat_dim_y.batch_size_ || if (context->IsRuntime()) {
mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0); PADDLE_ENFORCE(mat_dim_x.batch_size_ == mat_dim_y.batch_size_ ||
mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0);
}
std::vector<int64_t> dim_out; std::vector<int64_t> dim_out;
if (mat_dim_x.batch_size_ != 0) { if (mat_dim_x.batch_size_ != 0) {
dim_out = framework::vectorize(dim_x); dim_out = framework::vectorize(dim_x);
......
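The matmul change above wraps the batch-size equality check in context->IsRuntime() because shapes can still contain unknown dimensions while the graph is being built, and the two operands may only become consistent once real shapes are fed in; deferring the enforce to runtime is a common InferShape pattern. A reduced sketch of the same guard (the is_runtime flag and the helper are illustrative stand-ins, not the framework API):

#include <cassert>
#include <cstdio>

// Only enforce the batch-size match when concrete shapes are available,
// mirroring the context->IsRuntime() guard added to MatMulOp::InferShape.
void CheckBatchSizes(int bx, int by, bool is_runtime) {
  if (is_runtime) {
    assert(bx == by || bx == 0 || by == 0);
  }
  std::printf("bx=%d by=%d runtime=%d: ok\n", bx, by, is_runtime);
}

int main() {
  CheckBatchSizes(8, 8, /*is_runtime=*/true);    // concrete and matching
  CheckBatchSizes(-1, 8, /*is_runtime=*/false);  // unknown at compile time: skip
  return 0;
}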
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include <memory> #include <memory>
#include "paddle/fluid/operators/concat_op.h" #include "paddle/fluid/operators/concat_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -38,15 +39,20 @@ static void EnforceLayouts(const std::vector<const Tensor*> inputs) { ...@@ -38,15 +39,20 @@ static void EnforceLayouts(const std::vector<const Tensor*> inputs) {
} }
static memory::primitive_desc CreateMemPrimDesc(const Tensor& input, static memory::primitive_desc CreateMemPrimDesc(const Tensor& input,
const mkldnn::engine& engine) { const mkldnn::engine& engine,
constexpr auto data_type = mkldnn::memory::f32; const memory::data_type& dt) {
const auto dims = paddle::framework::vectorize2int(input.dims()); const auto dims = paddle::framework::vectorize2int(input.dims());
const auto format = input.format(); const auto format = input.format();
auto description = memory::desc(dims, data_type, format); auto description = memory::desc(dims, dt, format);
auto mem_prim_desc = memory::primitive_desc(description, engine); auto mem_prim_desc = memory::primitive_desc(description, engine);
return mem_prim_desc; return mem_prim_desc;
} }
static mkldnn::memory::format GetDstMemFormat(
const concat::primitive_desc& concat_pd) {
return (memory::format)concat_pd.dst_primitive_desc().desc().data.format;
}
static platform::CPUPlace GetCpuPlace( static platform::CPUPlace GetCpuPlace(
const paddle::framework::ExecutionContext& ctx) { const paddle::framework::ExecutionContext& ctx) {
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
...@@ -61,14 +67,30 @@ static const mkldnn::engine& GetMKLDNNEngine( ...@@ -61,14 +67,30 @@ static const mkldnn::engine& GetMKLDNNEngine(
return dev_ctx.GetEngine(); return dev_ctx.GetEngine();
} }
std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
const std::vector<const Tensor*> multi_input,
const int64_t& concat_axis, const memory::data_type& dt) {
std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength);
for (size_t i = 0; i < multi_input.size(); i++) {
platform::MKLDNNHandler::AppendKeyDims(
&key, paddle::framework::vectorize2int(multi_input[i]->dims()));
}
platform::MKLDNNHandler::AppendKey(&key, std::to_string(concat_axis));
platform::MKLDNNHandler::AppendKey(&key, ctx.op().Output("Out"));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt));
return key;
}
template <typename T> template <typename T>
class ConcatPrimitiveFactory { class ConcatPrimitiveFactory {
public: public:
concat::primitive_desc CreateConcatPrimDescriptor( concat::primitive_desc CreateConcatPrimDescriptor(
const std::vector<const Tensor*> multi_input, Tensor* output, const std::vector<const Tensor*> multi_input, Tensor* output,
int concat_axis, const mkldnn::engine& mkldnn_engine) { int concat_axis, const mkldnn::engine& mkldnn_engine,
CreateSourcesDescriptors(multi_input, mkldnn_engine); const memory::data_type& dt = memory::data_type::f32) {
auto dst_desc = CreateDstMemDescriptor(output); CreateSourcesDescriptors(multi_input, mkldnn_engine, dt);
auto dst_desc = CreateDstMemDescriptor(output, dt);
return concat::primitive_desc(dst_desc, concat_axis, srcs_pd); return concat::primitive_desc(dst_desc, concat_axis, srcs_pd);
} }
...@@ -79,23 +101,39 @@ class ConcatPrimitiveFactory { ...@@ -79,23 +101,39 @@ class ConcatPrimitiveFactory {
return concat(concat_pd, inputs, dst_mem.get()); return concat(concat_pd, inputs, dst_mem.get());
} }
void SetSrcDataHandleByIndex(const std::vector<memory>& srcs, const size_t& i,
void* handler) {
srcs[i].set_data_handle(handler);
}
void SetDstDataHandle(const memory& dst_mem, void* handler) {
dst_mem.set_data_handle(handler);
}
std::vector<memory> GetSrcs() { return srcs; }
memory GetDst() { return dst_mem.get(); }
private: private:
memory::desc CreateDstMemDescriptor(Tensor* output) { memory::desc CreateDstMemDescriptor(Tensor* output,
const memory::data_type& dt) {
auto dst_dims = paddle::framework::vectorize2int(output->dims()); auto dst_dims = paddle::framework::vectorize2int(output->dims());
return memory::desc(dst_dims, platform::MKLDNNGetDataType<T>(), return memory::desc(dst_dims, dt, memory::format::any);
memory::format::any);
} }
mkldnn::memory CreateDstMemory(const concat::primitive_desc& concat_pd, mkldnn::memory CreateDstMemory(const concat::primitive_desc& concat_pd,
Tensor* output, platform::CPUPlace place) { Tensor* output,
const platform::CPUPlace& place) {
return memory(concat_pd.dst_primitive_desc(), return memory(concat_pd.dst_primitive_desc(),
output->mutable_data<T>(place)); output->mutable_data<T>(place));
} }
void CreateSourcesDescriptors(const std::vector<const Tensor*> multi_input, void CreateSourcesDescriptors(const std::vector<const Tensor*> multi_input,
const mkldnn::engine& mkldnn_engine) { const mkldnn::engine& mkldnn_engine,
const memory::data_type& dt) {
for (size_t i = 0; i < multi_input.size(); i++) { for (size_t i = 0; i < multi_input.size(); i++) {
auto mem_prim_desc = CreateMemPrimDesc(*multi_input[i], mkldnn_engine); auto mem_prim_desc =
CreateMemPrimDesc(*multi_input[i], mkldnn_engine, dt);
srcs_pd.push_back(mem_prim_desc); srcs_pd.push_back(mem_prim_desc);
srcs.push_back( srcs.push_back(
memory(mem_prim_desc, to_void_cast(multi_input[i]->data<T>()))); memory(mem_prim_desc, to_void_cast(multi_input[i]->data<T>())));
...@@ -120,21 +158,59 @@ template <typename T> ...@@ -120,21 +158,59 @@ template <typename T>
class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> { class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public: public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override { void Compute(const paddle::framework::ExecutionContext& ctx) const override {
auto place = GetCpuPlace(ctx);
const auto& mkldnn_engine = GetMKLDNNEngine(ctx);
auto multi_input = ctx.MultiInput<Tensor>("X"); auto multi_input = ctx.MultiInput<Tensor>("X");
EnforceLayouts(multi_input); EnforceLayouts(multi_input);
Tensor* output = ctx.Output<Tensor>("Out"); Tensor* output = ctx.Output<Tensor>("Out");
int64_t concat_axis = static_cast<int64_t>(ctx.Attr<int>("axis")); int64_t concat_axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
auto place = GetCpuPlace(ctx);
memory::data_type dt =
paddle::framework::ToMKLDNNDataType(multi_input[0]->type());
ConcatPrimitiveFactory<T> prim_creator; ConcatPrimitiveFactory<T> prim_creator;
auto concat_pd = prim_creator.CreateConcatPrimDescriptor( std::string key = CreateKey(ctx, multi_input, concat_axis, dt);
multi_input, output, static_cast<int>(concat_axis), mkldnn_engine); const std::string key_prim = key + "@concat_p";
auto concat = prim_creator.CreateConcatPrimitive(concat_pd, output, place); const std::string key_concat_pd = key + "@concat_pd";
stream(stream::kind::eager).submit({concat}).wait(); const std::string key_srcs = key + "@concat_srcs";
const std::string key_dst = key + "@concat_dst";
std::shared_ptr<concat::primitive_desc> concat_pd;
std::shared_ptr<std::vector<memory>> srcs;
std::shared_ptr<memory> dst_mem;
auto concat_p = std::static_pointer_cast<concat>(dev_ctx.GetBlob(key_prim));
if (concat_p == nullptr) {
const auto& mkldnn_engine = dev_ctx.GetEngine();
concat_pd = std::make_shared<concat::primitive_desc>(
prim_creator.CreateConcatPrimDescriptor(multi_input, output,
static_cast<int>(concat_axis),
mkldnn_engine, dt));
concat_p = std::make_shared<concat>(
prim_creator.CreateConcatPrimitive(*concat_pd, output, place));
srcs = std::make_shared<std::vector<memory>>(prim_creator.GetSrcs());
dst_mem = std::make_shared<memory>(prim_creator.GetDst());
dev_ctx.SetBlob(key_prim, concat_p);
dev_ctx.SetBlob(key_concat_pd, concat_pd);
dev_ctx.SetBlob(key_srcs, srcs);
dev_ctx.SetBlob(key_dst, dst_mem);
} else {
srcs = std::static_pointer_cast<std::vector<memory>>(
dev_ctx.GetBlob(key_srcs));
dst_mem = std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_dst));
concat_pd = std::static_pointer_cast<concat::primitive_desc>(
dev_ctx.GetBlob(key_concat_pd));
for (size_t i = 0; i < multi_input.size(); i++) {
prim_creator.SetSrcDataHandleByIndex(
*srcs, i, to_void_cast<T>(multi_input[i]->data<T>()));
}
prim_creator.SetDstDataHandle(*dst_mem, output->mutable_data<T>(place));
}
stream(stream::kind::eager).submit({*concat_p}).wait();
output->set_mkldnn_prim_desc(concat_pd.dst_primitive_desc()); output->set_mkldnn_prim_desc(concat_pd->dst_primitive_desc());
} }
}; };
} // namespace operators } // namespace operators
...@@ -143,4 +219,6 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -143,4 +219,6 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_KERNEL(concat, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(concat, MKLDNN, ::paddle::platform::CPUPlace,
ops::ConcatMKLDNNOpKernel<float>) ops::ConcatMKLDNNOpKernel<float>,
ops::ConcatMKLDNNOpKernel<int8_t>,
ops::ConcatMKLDNNOpKernel<uint8_t>);
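The concat kernel above now caches the MKLDNN primitive, its source memories and its destination memory in the device context's blob map, keyed by input shapes, axis, output name and data type, so later iterations only swap data handles instead of rebuilding primitives. A simplified sketch of that get-or-create pattern with a plain string-keyed cache (the cache type, the sample key string and all names here are illustrative; the real code goes through dev_ctx.GetBlob/SetBlob):

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Stand-in for an expensive-to-build primitive.
struct Primitive {
  explicit Primitive(std::string key) : key(std::move(key)) {
    std::cout << "built primitive for " << this->key << "\n";
  }
  std::string key;
};

class BlobCache {
 public:
  std::shared_ptr<Primitive> GetOrCreate(const std::string& key) {
    auto it = blobs_.find(key);
    if (it != blobs_.end()) return it->second;  // reuse: only handles change
    auto p = std::make_shared<Primitive>(key);
    blobs_.emplace(key, p);
    return p;
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<Primitive>> blobs_;
};

int main() {
  BlobCache cache;
  // Key mirrors CreateKey(): input dims + axis + output name + data type.
  const std::string key = "2x3x4x4-2x5x4x4@1@concat.tmp_0@f32@concat_p";
  auto first = cache.GetOrCreate(key);   // builds
  auto second = cache.GetOrCreate(key);  // hits the cache
  std::cout << (first == second ? "reused\n" : "rebuilt\n");
  return 0;
}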
...@@ -325,7 +325,8 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc, ...@@ -325,7 +325,8 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
const bool is_output = outputs.find(var_name) != outputs.end(); const bool is_output = outputs.find(var_name) != outputs.end();
if (!is_output && if (!is_output &&
std::find(var_in_.begin(), var_in_.end(), var_name) == std::find(var_in_.begin(), var_in_.end(), var_name) ==
var_in_.end()) { var_in_.end() &&
scope_.FindVar(var_name)) {
// fill var_in here to keep lhs and rhs order // fill var_in here to keep lhs and rhs order
this->var_in_.emplace_back(var_name); this->var_in_.emplace_back(var_name);
} }
......
...@@ -27,13 +27,9 @@ namespace paddle { ...@@ -27,13 +27,9 @@ namespace paddle {
namespace operators { namespace operators {
namespace ngraphs { namespace ngraphs {
void BuildCrossEntropyNode( std::shared_ptr<ngraph::Node> GetCrossEntropy(
const std::shared_ptr<paddle::framework::OperatorBase>& op, std::shared_ptr<ngraph::Node> x, std::shared_ptr<ngraph::Node> label,
std::shared_ptr< const bool is_soft_label, int ignore_index) {
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
auto label_shape = label->get_shape(); auto label_shape = label->get_shape();
auto x_shape = x->get_shape(); auto x_shape = x->get_shape();
auto label_rank = label_shape.size(); auto label_rank = label_shape.size();
...@@ -46,18 +42,16 @@ void BuildCrossEntropyNode( ...@@ -46,18 +42,16 @@ void BuildCrossEntropyNode(
label_2d = paddle::platform::NgReshaper(label, label_2d_shape); label_2d = paddle::platform::NgReshaper(label, label_2d_shape);
} }
if (x_rank > 2) { if (x_rank > 2) {
x_2d_shape = paddle::platform::FlattenTo2d(x_shape, x_rank - 1); x_2d_shape = platform::FlattenTo2d(x_shape, x_rank - 1);
x_2d = paddle::platform::NgReshaper(x, x_2d_shape); x_2d = platform::NgReshaper(x, x_2d_shape);
} }
auto batch_size = x_2d_shape.at(0); auto batch_size = x_2d_shape.at(0);
auto op_attrs = paddle::framework::AttrReader(op->Attrs());
const bool is_soft_label = op_attrs.Get<bool>("soft_label");
std::shared_ptr<ngraph::Node> node_1_hot = label_2d; std::shared_ptr<ngraph::Node> node_1_hot = label_2d;
if (!is_soft_label) { if (!is_soft_label) {
auto label_1d = paddle::platform::NgReshaper( auto label_1d =
label_2d, ngraph::Shape{label_2d_shape.at(0)}); platform::NgReshaper(label_2d, ngraph::Shape{label_2d_shape.at(0)});
node_1_hot = std::make_shared<ngraph::op::OneHot>(label_1d, x_2d_shape, 1); node_1_hot = std::make_shared<ngraph::op::OneHot>(label_1d, x_2d_shape, 1);
} }
if (x->get_element_type() != node_1_hot->get_element_type()) { if (x->get_element_type() != node_1_hot->get_element_type()) {
...@@ -76,11 +70,9 @@ void BuildCrossEntropyNode( ...@@ -76,11 +70,9 @@ void BuildCrossEntropyNode(
auto node_sum = auto node_sum =
std::make_shared<ngraph::op::Sum>(node_mul, ngraph::AxisSet{1}); std::make_shared<ngraph::op::Sum>(node_mul, ngraph::AxisSet{1});
auto node_neg = std::make_shared<ngraph::op::Negative>(node_sum); auto node_neg = std::make_shared<ngraph::op::Negative>(node_sum);
auto xe = auto xe = platform::NgReshaper(node_neg, ngraph::Shape{batch_size, 1});
paddle::platform::NgReshaper(node_neg, ngraph::Shape{batch_size, 1});
if (!is_soft_label) { if (!is_soft_label) {
auto ignore_index = op_attrs.Get<int>("ignore_index");
auto ignore_node = ngraph::op::Constant::create( auto ignore_node = ngraph::op::Constant::create(
label->get_element_type(), label_2d_shape, {ignore_index}); label->get_element_type(), label_2d_shape, {ignore_index});
auto not_equal_node = auto not_equal_node =
...@@ -89,21 +81,13 @@ void BuildCrossEntropyNode( ...@@ -89,21 +81,13 @@ void BuildCrossEntropyNode(
xe->get_element_type()); xe->get_element_type());
xe = xe * mask; xe = xe * mask;
} }
return xe;
paddle::platform::SetOutputNode(op, "Y", xe, ngb_node_map);
} }
void BuildCrossEntropyGradNode( std::shared_ptr<ngraph::Node> GetCrossEntropyGrad(
const std::shared_ptr<paddle::framework::OperatorBase>& op, std::shared_ptr<ngraph::Node> x, std::shared_ptr<ngraph::Node> label,
std::shared_ptr< std::shared_ptr<ngraph::Node> dy, const bool is_soft_label,
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>> int ignore_index) {
ngb_node_map) {
auto op_attrs = paddle::framework::AttrReader(op->Attrs());
const bool is_soft_label = op_attrs.Get<bool>("soft_label");
auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
auto dy = paddle::platform::GetInputNode(op, "Y@GRAD", ngb_node_map);
auto x_shape = x->get_shape(); auto x_shape = x->get_shape();
auto rank = x_shape.size(); auto rank = x_shape.size();
...@@ -111,9 +95,8 @@ void BuildCrossEntropyGradNode( ...@@ -111,9 +95,8 @@ void BuildCrossEntropyGradNode(
if (!is_soft_label) { if (!is_soft_label) {
auto label_shape = label->get_shape(); auto label_shape = label->get_shape();
label_shape.pop_back(); label_shape.pop_back();
label = paddle::platform::NgReshaper(label, label_shape); label = platform::NgReshaper(label, label_shape);
auto ignore_index = op_attrs.Get<int>("ignore_index");
auto ignore_node = ngraph::op::Constant::create( auto ignore_node = ngraph::op::Constant::create(
label->get_element_type(), label_shape, {ignore_index}); label->get_element_type(), label_shape, {ignore_index});
auto not_equal_node = auto not_equal_node =
...@@ -128,7 +111,7 @@ void BuildCrossEntropyGradNode( ...@@ -128,7 +111,7 @@ void BuildCrossEntropyGradNode(
auto dy_shape = dy->get_shape(); auto dy_shape = dy->get_shape();
dy_shape.pop_back(); dy_shape.pop_back();
auto dy_reshape = paddle::platform::NgReshaper(dy, dy_shape); auto dy_reshape = platform::NgReshaper(dy, dy_shape);
auto dy_bcast = std::make_shared<ngraph::op::Broadcast>( auto dy_bcast = std::make_shared<ngraph::op::Broadcast>(
dy_reshape, x_shape, ngraph::AxisSet{rank - 1}); dy_reshape, x_shape, ngraph::AxisSet{rank - 1});
if (x->get_element_type() != label->get_element_type()) { if (x->get_element_type() != label->get_element_type()) {
...@@ -140,7 +123,35 @@ void BuildCrossEntropyGradNode( ...@@ -140,7 +123,35 @@ void BuildCrossEntropyGradNode(
if (!is_soft_label) { if (!is_soft_label) {
xe_grad = xe_grad * mask; xe_grad = xe_grad * mask;
} }
return xe_grad;
}
void BuildCrossEntropyNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
auto op_attrs = paddle::framework::AttrReader(op->Attrs());
const bool is_soft_label = op_attrs.Get<bool>("soft_label");
int ignore_index = op_attrs.Get<int>("ignore_index");
auto xe = GetCrossEntropy(x, label, is_soft_label, ignore_index);
paddle::platform::SetOutputNode(op, "Y", xe, ngb_node_map);
}
void BuildCrossEntropyGradNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto op_attrs = paddle::framework::AttrReader(op->Attrs());
const bool is_soft_label = op_attrs.Get<bool>("soft_label");
int ignore_index = op_attrs.Get<int>("ignore_index");
auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
auto dy = paddle::platform::GetInputNode(op, "Y@GRAD", ngb_node_map);
auto xe_grad = GetCrossEntropyGrad(x, label, dy, is_soft_label, ignore_index);
paddle::platform::SetOutputNode(op, "X@GRAD", xe_grad, ngb_node_map); paddle::platform::SetOutputNode(op, "X@GRAD", xe_grad, ngb_node_map);
} }
} // namespace ngraphs } // namespace ngraphs
......
...@@ -27,12 +27,7 @@ namespace paddle { ...@@ -27,12 +27,7 @@ namespace paddle {
namespace operators { namespace operators {
namespace ngraphs { namespace ngraphs {
void BuildSoftmaxNode( std::shared_ptr<ngraph::Node> GetSoftmax(std::shared_ptr<ngraph::Node> x) {
const std::shared_ptr<paddle::framework::OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
auto x_shape = x->get_shape(); auto x_shape = x->get_shape();
int rank = x_shape.size(); int rank = x_shape.size();
auto x_2d_shape = paddle::platform::FlattenTo2d(x_shape, rank - 1); auto x_2d_shape = paddle::platform::FlattenTo2d(x_shape, rank - 1);
...@@ -47,16 +42,11 @@ void BuildSoftmaxNode( ...@@ -47,16 +42,11 @@ void BuildSoftmaxNode(
-64., x_shifted); -64., x_shifted);
auto softmax = auto softmax =
std::make_shared<ngraph::op::Softmax>(x_clipped, ngraph::AxisSet{1}); std::make_shared<ngraph::op::Softmax>(x_clipped, ngraph::AxisSet{1});
paddle::platform::SetOutputNode(op, "Out", softmax, ngb_node_map); return softmax;
} }
void BuildSoftmaxGradNode( std::shared_ptr<ngraph::Node> GetSoftmaxGrad(
const std::shared_ptr<paddle::framework::OperatorBase>& op, std::shared_ptr<ngraph::Node> out, std::shared_ptr<ngraph::Node> dout) {
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto out = paddle::platform::GetInputNode(op, "Out", ngb_node_map);
auto dout = paddle::platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
auto out_shape = out->get_shape(); auto out_shape = out->get_shape();
int rank = out_shape.size(); int rank = out_shape.size();
auto out_2d_shape = paddle::platform::FlattenTo2d(out_shape, rank - 1); auto out_2d_shape = paddle::platform::FlattenTo2d(out_shape, rank - 1);
...@@ -70,6 +60,27 @@ void BuildSoftmaxGradNode( ...@@ -70,6 +60,27 @@ void BuildSoftmaxGradNode(
auto node_bcast = std::make_shared<ngraph::op::Broadcast>( auto node_bcast = std::make_shared<ngraph::op::Broadcast>(
node_sum, out_2d_shape, ngraph::AxisSet{1}); node_sum, out_2d_shape, ngraph::AxisSet{1});
auto dx = (dout - node_bcast) * out; auto dx = (dout - node_bcast) * out;
return dx;
}
void BuildSoftmaxNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
auto softmax = GetSoftmax(x);
paddle::platform::SetOutputNode(op, "Out", softmax, ngb_node_map);
}
void BuildSoftmaxGradNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto out = paddle::platform::GetInputNode(op, "Out", ngb_node_map);
auto dout = paddle::platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
auto dx = GetSoftmaxGrad(out, dout);
paddle::platform::SetOutputNode(op, "X@GRAD", dx, ngb_node_map); paddle::platform::SetOutputNode(op, "X@GRAD", dx, ngb_node_map);
} }
} // namespace ngraphs } // namespace ngraphs
......
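GetSoftmax above follows the usual numerically stable recipe: subtract the row maximum, clip the shifted logits at -64 so the exponentials stay in range, then apply Softmax along axis 1. A standalone C++ version of that forward pass for a single row (no ngraph involved; written here only to make the op sequence concrete):

// Numerically stable softmax of one row: shift by the max, clip at -64,
// exponentiate, normalize -- the same sequence GetSoftmax builds from ngraph ops.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

std::vector<double> StableSoftmax(std::vector<double> x) {
  const double x_max = *std::max_element(x.begin(), x.end());
  double sum = 0.0;
  for (double& v : x) {
    v = std::max(v - x_max, -64.0);
    v = std::exp(v);
    sum += v;
  }
  for (double& v : x) v /= sum;
  return x;
}

int main() {
  for (double p : StableSoftmax({1.0, 2.0, 3.0})) std::printf("%.4f ", p);
  std::printf("\n");  // roughly 0.0900 0.2447 0.6652
  return 0;
}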
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "ngraph/ngraph.hpp"
#include "paddle/fluid/operators/ngraph/ops/cross_entropy_op.h"
#include "paddle/fluid/operators/ngraph/ops/softmax_op.h"
#include "paddle/fluid/platform/ngraph_helper.h"
namespace paddle {
namespace operators {
namespace ngraphs {
void BuildSoftmaxWithCrossEntropyNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto logits = paddle::platform::GetInputNode(op, "Logits", ngb_node_map);
auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
auto softmax = paddle::operators::ngraphs::GetSoftmax(logits);
auto op_attrs = framework::AttrReader(op->Attrs());
const bool is_soft_label = op_attrs.Get<bool>("soft_label");
int ignore_index = op_attrs.Get<int>("ignore_index");
auto xe = paddle::operators::ngraphs::GetCrossEntropy(
softmax, label, is_soft_label, ignore_index);
paddle::platform::SetOutputNode(op, "Softmax", softmax, ngb_node_map);
paddle::platform::SetOutputNode(op, "Loss", xe, ngb_node_map);
}
void BuildSoftmaxWithCrossEntropyGradNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto op_attrs = framework::AttrReader(op->Attrs());
const bool is_soft_label = op_attrs.Get<bool>("soft_label");
auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
auto softmax = paddle::platform::GetInputNode(op, "Softmax", ngb_node_map);
auto loss_grad =
paddle::platform::GetInputNode(op, "Loss@GRAD", ngb_node_map);
auto softmax_shape = softmax->get_shape();
auto rank = softmax_shape.size();
if (!is_soft_label) {
auto label_shape = label->get_shape();
label_shape.pop_back();
label = platform::NgReshaper(label, label_shape);
label =
std::make_shared<ngraph::op::OneHot>(label, softmax_shape, rank - 1);
}
auto loss_grad_shape = loss_grad->get_shape();
loss_grad_shape.pop_back();
auto loss_grad_reshape = platform::NgReshaper(loss_grad, loss_grad_shape);
auto loss_grad_bcast = std::make_shared<ngraph::op::Broadcast>(
loss_grad_reshape, softmax_shape, ngraph::AxisSet{rank - 1});
if (softmax->get_element_type() != label->get_element_type()) {
label = std::make_shared<ngraph::op::Convert>(label,
softmax->get_element_type());
}
auto logits_grad = loss_grad_bcast * (softmax - label);
paddle::platform::SetOutputNode(op, "Logits@GRAD", logits_grad, ngb_node_map);
}
} // namespace ngraphs
} // namespace operators
} // namespace paddle
REGISTER_NG_OP(softmax_with_cross_entropy, BuildSoftmaxWithCrossEntropyNode);
REGISTER_NG_OP(softmax_with_cross_entropy_grad,
BuildSoftmaxWithCrossEntropyGradNode);
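The fused gradient built above reduces to the familiar closed form for hard labels: d loss / d logits = loss_grad * (softmax - one_hot(label)), which is exactly the loss_grad_bcast * (softmax - label) expression assembled from ngraph ops. A small standalone check of that formula for one row (plain C++, reusing the stable-softmax idea from the earlier sketch):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// d(loss)/d(logit_j) = dloss * (softmax_j - [j == label]) for hard labels,
// matching logits_grad = loss_grad_bcast * (softmax - one_hot(label)) above.
std::vector<double> SoftmaxXentGrad(const std::vector<double>& logits,
                                    int label, double dloss) {
  const double m = *std::max_element(logits.begin(), logits.end());
  std::vector<double> p(logits.size());
  double sum = 0.0;
  for (size_t j = 0; j < logits.size(); ++j) {
    sum += (p[j] = std::exp(logits[j] - m));
  }
  std::vector<double> grad(logits.size());
  for (size_t j = 0; j < logits.size(); ++j) {
    grad[j] = dloss * (p[j] / sum - (static_cast<int>(j) == label ? 1.0 : 0.0));
  }
  return grad;
}

int main() {
  for (double g : SoftmaxXentGrad({1.0, 2.0, 3.0}, /*label=*/2, /*dloss=*/1.0)) {
    std::printf("%.4f ", g);  // roughly 0.0900 0.2447 -0.3348
  }
  std::printf("\n");
  return 0;
}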
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/range_op.h"
namespace paddle {
namespace operators {
class RangeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
if (ctx->HasInput("Start")) {
auto s_dims = ctx->GetInputDim("Start");
PADDLE_ENFORCE((s_dims.size() == 1) && (s_dims[0] == 1),
"The shape of Input(Start) should be [1].");
}
if (ctx->HasInput("End")) {
auto e_dims = ctx->GetInputDim("End");
PADDLE_ENFORCE((e_dims.size() == 1) && (e_dims[0] == 1),
"The shape of Input(End) should be [1].");
}
if (ctx->HasInput("Step")) {
auto step_dims = ctx->GetInputDim("Step");
PADDLE_ENFORCE((step_dims.size() == 1) && (step_dims[0] == 1),
"The shape of Input(Step) should be [1].");
}
ctx->SetOutputDim("Out", {-1});
}
};
class RangeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Start",
"Start of interval. The interval includes this value. It is a "
"tensor with shape=[1].");
AddInput("End",
"End of interval. The interval does not include this value, "
"except in some cases where step is not an integer and floating "
"point round-off affects the length of out. It is a tensor with "
"shape=[1].");
AddInput("Step", "Spacing between values. It is a tensor with shape=[1].");
AddOutput("Out", "A sequence of numbers.");
AddComment(R"DOC(
    Return evenly spaced values within a given interval. Values are generated within the half-open interval [Start, End), that is, including Start but excluding End, similar to numpy's arange function.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(range, ops::RangeOp, ops::RangeOpMaker);
REGISTER_OP_CPU_KERNEL(range, ops::CPURangeKernel<int>,
ops::CPURangeKernel<float>, ops::CPURangeKernel<double>,
ops::CPURangeKernel<int64_t>);
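GetSize, defined in range_op.h further below, determines how many elements these kernels emit: for integral types it is ceil(|End - Start| / |Step|), written as (|End - Start| + |Step| - 1) / |Step|, so Start=0, End=10, Step=3 gives (10 + 3 - 1) / 3 = 4 values: 0, 3, 6, 9. A quick standalone check of that arithmetic (illustration only, not Paddle code):

#include <cstdio>
#include <cstdlib>

// Integer branch of GetSize: ceil(|end - start| / |step|) without floating point.
long long RangeSize(long long start, long long end, long long step) {
  return (std::llabs(end - start) + std::llabs(step) - 1) / std::llabs(step);
}

int main() {
  long long start = 0, end = 10, step = 3;
  long long size = RangeSize(start, end, step);
  std::printf("size = %lld:", size);  // 4
  for (long long i = 0; i < size; ++i) std::printf(" %lld", start + i * step);
  std::printf("\n");  // 0 3 6 9
  return 0;
}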
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/range_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace operators {
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
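// CUDA_1D_KERNEL_LOOP is a grid-stride loop: each thread starts at its global
// index and advances by the total number of launched threads, so a fixed grid
// can cover an output of any length.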
template <typename T>
__global__ void RangeKernel(T start, T step, int64_t size, T* out) {
CUDA_1D_KERNEL_LOOP(index, size) { out[index] = start + step * index; }
}
template <typename T>
class CUDARangeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* start_t = context.Input<framework::Tensor>("Start");
auto* end_t = context.Input<framework::Tensor>("End");
auto* step_t = context.Input<framework::Tensor>("Step");
auto* out = context.Output<framework::Tensor>("Out");
framework::Tensor n;
framework::TensorCopy(*start_t, platform::CPUPlace(), &n);
T start = n.data<T>()[0];
framework::TensorCopy(*end_t, platform::CPUPlace(), &n);
T end = n.data<T>()[0];
framework::TensorCopy(*step_t, platform::CPUPlace(), &n);
T step = n.data<T>()[0];
int64_t size = 0;
GetSize(start, end, step, &size);
out->Resize(framework::make_ddim({size}));
T* out_data = out->mutable_data<T>(context.GetPlace());
auto stream = context.cuda_device_context().stream();
int block = 512;
int grid = (size + block - 1) / block;
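    // Round the grid size up so that grid * block >= size; the grid-stride
    // loop in RangeKernel keeps threads from writing past the end of Out.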
RangeKernel<T><<<grid, block, 0, stream>>>(start, step, size, out_data);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(range, ops::CUDARangeKernel<int>,
ops::CUDARangeKernel<int64_t>,
ops::CUDARangeKernel<float>,
ops::CUDARangeKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
template <typename T>
void GetSize(T start, T end, T step, int64_t* size) {
PADDLE_ENFORCE(!std::equal_to<T>()(step, 0),
"The step of range op should not be 0.");
PADDLE_ENFORCE(((start < end) && (step > 0)) || ((start > end) && (step < 0)),
"The step should be greater than 0 while start < end. And the "
"step should be less than 0 while start > end.");
*size = std::is_integral<T>::value
? ((std::abs(end - start) + std::abs(step) - 1) / std::abs(step))
: std::ceil(std::abs((end - start) / step));
}
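// Illustrative example (not from the original source): with start = 1, end = 10
// and step = 3 for an integral T, size = (|10 - 1| + |3| - 1) / |3| = 11 / 3 = 3,
// matching the generated sequence {1, 4, 7}.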
template <typename T>
class CPURangeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
T start = context.Input<framework::Tensor>("Start")->data<T>()[0];
T end = context.Input<framework::Tensor>("End")->data<T>()[0];
T step = context.Input<framework::Tensor>("Step")->data<T>()[0];
auto* out = context.Output<framework::Tensor>("Out");
int64_t size = 0;
GetSize(start, end, step, &size);
out->Resize(framework::make_ddim({size}));
T* out_data = out->mutable_data<T>(context.GetPlace());
T value = start;
for (int64_t i = 0; i < size; ++i) {
out_data[i] = value;
value += step;
}
}
};
} // namespace operators
} // namespace paddle
...@@ -12,87 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,87 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <stdint.h> #include <string>
#include <fstream>
#include <numeric> #include "paddle/fluid/operators/save_combine_op.h"
#include <sstream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/port.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class SaveCombineOp : public framework::OperatorBase { class SaveCombineOp : public framework::OperatorWithKernel {
public: public:
SaveCombineOp(const std::string &type, using framework::OperatorWithKernel::OperatorWithKernel;
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto filename = Attr<std::string>("file_path");
auto overwrite = Attr<bool>("overwrite");
auto save_as_fp16 = Attr<bool>("save_as_fp16");
bool is_present = FileExists(filename);
if (is_present && !overwrite) {
PADDLE_THROW("%s exists!, cannot save_combine to it when overwrite=false",
filename, overwrite);
}
MkDirRecursively(DirName(filename).c_str());
std::ofstream fout(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
filename);
auto inp_var_names = Inputs("X");
PADDLE_ENFORCE_GT(static_cast<int>(inp_var_names.size()), 0,
"The number of input variables should be greater than 0");
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
for (size_t i = 0; i < inp_var_names.size(); i++) { void InferShape(framework::InferShapeContext *ctx) const override {}
auto *var = scope.FindVar(inp_var_names[i]);
PADDLE_ENFORCE(var != nullptr,
"Cannot find variable %s for save_combine_op",
inp_var_names[i]);
PADDLE_ENFORCE(var->IsType<framework::LoDTensor>(),
"SaveCombineOp only supports LoDTensor, %s has wrong type",
inp_var_names[i]);
auto &tensor = var->Get<framework::LoDTensor>();
// Serialize tensors one by one
// Check types to see if a fp16 transformation is required
auto in_dtype = tensor.type();
auto out_dtype =
save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) {
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
framework::LoDTensor out;
// copy LoD info to the new tensor
out.set_lod(tensor.lod());
framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out);
framework::SerializeToStream(fout, out, dev_ctx);
} else {
framework::SerializeToStream(fout, tensor, dev_ctx);
}
}
fout.close();
}
}; };
class SaveCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker { class SaveCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
...@@ -105,7 +36,7 @@ class SaveCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -105,7 +36,7 @@ class SaveCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC( AddComment(R"DOC(
SaveCombine operator SaveCombine operator
This operator will serialize and write a list of input LoDTensor variables This operator will serialize and write a list of input LoDTensor variables
to a file on disk. to a file on disk.
)DOC"); )DOC");
AddAttr<bool>("overwrite", AddAttr<bool>("overwrite",
...@@ -134,3 +65,10 @@ namespace ops = paddle::operators; ...@@ -134,3 +65,10 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(save_combine, ops::SaveCombineOp, REGISTER_OPERATOR(save_combine, ops::SaveCombineOp,
ops::SaveCombineOpProtoMaker); ops::SaveCombineOpProtoMaker);
REGISTER_OP_CPU_KERNEL(
save_combine,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/save_combine_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
save_combine,
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
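// The CUDA build reuses the same templated kernel: SerializeToStream is
// expected to copy device tensors back to the host before writing, so no
// device-specific logic is needed here.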
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdint.h>
#include <fstream>
#include <numeric>
#include <sstream>
#include <string>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/port.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class SaveCombineOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace();
auto filename = ctx.Attr<std::string>("file_path");
auto overwrite = ctx.Attr<bool>("overwrite");
auto save_as_fp16 = ctx.Attr<bool>("save_as_fp16");
bool is_present = FileExists(filename);
if (is_present && !overwrite) {
PADDLE_THROW("%s exists!, cannot save_combine to it when overwrite=false",
filename, overwrite);
}
MkDirRecursively(DirName(filename).c_str());
std::ofstream fout(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
filename);
auto &inp_var_names = ctx.Inputs("X");
auto &inp_vars = ctx.MultiInputVar("X");
PADDLE_ENFORCE_GT(static_cast<int>(inp_var_names.size()), 0,
"The number of input variables should be greater than 0");
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
for (size_t i = 0; i < inp_var_names.size(); i++) {
PADDLE_ENFORCE(inp_vars[i] != nullptr,
"Cannot find variable %s for save_combine_op",
inp_var_names[i]);
PADDLE_ENFORCE(inp_vars[i]->IsType<framework::LoDTensor>(),
"SaveCombineOp only supports LoDTensor, %s has wrong type",
inp_var_names[i]);
auto &tensor = inp_vars[i]->Get<framework::LoDTensor>();
// Serialize tensors one by one
// Check types to see if a fp16 transformation is required
auto in_dtype = tensor.type();
auto out_dtype =
save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) {
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
framework::LoDTensor out;
// copy LoD info to the new tensor
out.set_lod(tensor.lod());
framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out);
framework::SerializeToStream(fout, out, dev_ctx);
} else {
framework::SerializeToStream(fout, tensor, dev_ctx);
}
}
fout.close();
}
};
} // namespace operators
} // namespace paddle
...@@ -19,8 +19,8 @@ limitations under the License. */ ...@@ -19,8 +19,8 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
USE_NO_KERNEL_OP(save_combine); USE_CPU_ONLY_OP(save_combine);
USE_NO_KERNEL_OP(load_combine); USE_CPU_ONLY_OP(load_combine);
template <typename T, typename U> template <typename T, typename U>
T* CreateForSaveCombineOp(int x, int y, const std::vector<int>& lod_info, T* CreateForSaveCombineOp(int x, int y, const std::vector<int>& lod_info,
......
...@@ -16,8 +16,8 @@ limitations under the License. */ ...@@ -16,8 +16,8 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
USE_NO_KERNEL_OP(save); USE_CPU_ONLY_OP(save);
USE_NO_KERNEL_OP(load); USE_CPU_ONLY_OP(load);
TEST(SaveLoadOp, CPU) { TEST(SaveLoadOp, CPU) {
paddle::framework::Scope scope; paddle::framework::Scope scope;
......
...@@ -15,118 +15,24 @@ limitations under the License. */ ...@@ -15,118 +15,24 @@ limitations under the License. */
#include <stdint.h> #include <stdint.h>
#include <fstream> #include <fstream>
#include <numeric> #include <numeric>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/operators/save_op.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/port.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class SaveOp : public framework::OperatorWithKernel {
// define LOOKUP_TABLE_PATH for checkpoint notify to save lookup table variables
// to directory specified.
constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath";
class SaveOp : public framework::OperatorBase {
public: public:
SaveOp(const std::string &type, const framework::VariableNameMap &inputs, using framework::OperatorWithKernel::OperatorWithKernel;
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto iname = Input("X");
auto *var = scope.FindVar(iname);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s for save_op",
iname);
if (var->IsType<framework::LoDTensor>()) {
SaveLodTensor(place, var);
} else if (var->IsType<framework::SelectedRows>()) {
SaveSelectedRows(scope, place, var);
} else {
PADDLE_ENFORCE(
false,
"SaveOp only support LoDTensor and SelectedRows, %s has wrong type",
iname);
}
}
void SaveLodTensor(const platform::Place &place, void InferShape(framework::InferShapeContext *ctx) const override {}
framework::Variable *var) const {
auto filename = Attr<std::string>("file_path");
auto overwrite = Attr<bool>("overwrite");
if (FileExists(filename) && !overwrite) {
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
filename, overwrite);
}
MkDirRecursively(DirName(filename).c_str());
auto &tensor = var->Get<framework::LoDTensor>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ofstream fout(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
filename);
auto save_as_fp16 = Attr<bool>("save_as_fp16");
auto in_dtype = tensor.type();
auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) {
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
framework::LoDTensor out;
framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out);
// copy LoD info to the new tensor
out.set_lod(tensor.lod());
framework::SerializeToStream(fout, out, dev_ctx);
} else {
framework::SerializeToStream(fout, tensor, dev_ctx);
}
fout.close();
}
void SaveSelectedRows(const framework::Scope &scope, protected:
const platform::Place &place, framework::OpKernelType GetExpectedKernelType(
framework::Variable *var) const { const framework::ExecutionContext &ctx) const override {
auto *lt_var = scope.FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>(); return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
PADDLE_ENFORCE( ctx.GetPlace());
lt_var != nullptr,
"Can not find variable kLookupTablePath for SaveSelectedRows");
std::string filename = lt_var->data();
VLOG(4) << "SaveSelectedRows get File name: " << filename;
MkDirRecursively(DirName(filename).c_str());
auto &selectedRows = var->Get<framework::SelectedRows>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ofstream fout(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
filename);
framework::SerializeToStream(fout, selectedRows, dev_ctx);
fout.close();
} }
}; };
...@@ -154,14 +60,20 @@ This operator will serialize and write LoDTensor / SelectedRows variable to file ...@@ -154,14 +60,20 @@ This operator will serialize and write LoDTensor / SelectedRows variable to file
"The \"file_path\" where the variable will be saved.") "The \"file_path\" where the variable will be saved.")
.AddCustomChecker( .AddCustomChecker(
[](const std::string &path) { return !path.empty(); }); [](const std::string &path) { return !path.empty(); });
AddOutput(LOOKUP_TABLE_PATH,
"(string)"
"for pserver: The \"kLookupTablePath\" where checkpoint notify "
"to save lookup table variables"
" to directory specified.")
.AsDispensable();
} }
}; };
class SaveOpVarTypeInference : public framework::VarTypeInference { class SaveOpVarTypeInference : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext *ctx) const override { void operator()(framework::InferVarTypeContext *ctx) const override {
auto out_var_name = ctx->Output(LOOKUP_TABLE_PATH).front(); auto var_type = framework::proto::VarType::RAW;
ctx->SetType(out_var_name, framework::proto::VarType::RAW); ctx->SetType(LOOKUP_TABLE_PATH, var_type);
} }
}; };
...@@ -169,11 +81,18 @@ class SaveOpShapeInference : public framework::InferShapeBase { ...@@ -169,11 +81,18 @@ class SaveOpShapeInference : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext *ctx) const override {} void operator()(framework::InferShapeContext *ctx) const override {}
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(save, ops::SaveOp, paddle::framework::EmptyGradOpMaker, REGISTER_OPERATOR(save, ops::SaveOp, ops::SaveOpProtoMaker,
ops::SaveOpProtoMaker, ops::SaveOpVarTypeInference, ops::SaveOpVarTypeInference, ops::SaveOpShapeInference);
ops::SaveOpShapeInference);
REGISTER_OP_CPU_KERNEL(
save, ops::SaveOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/save_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
save, ops::SaveOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdint.h>
#include <fstream>
#include <numeric>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable.h"
namespace paddle {
namespace operators {
// LOOKUP_TABLE_PATH is used by checkpoint notify to save lookup table
// variables to the directory it specifies.
constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath";
template <typename DeviceContext, typename T>
class SaveOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace();
auto *input_var = ctx.InputVar("X");
    auto iname = ctx.Inputs("X")[0];
PADDLE_ENFORCE(input_var != nullptr, "Cannot find variable %s for save_op",
iname);
if (input_var->IsType<framework::LoDTensor>()) {
SaveLodTensor(ctx, place, input_var);
} else if (input_var->IsType<framework::SelectedRows>()) {
SaveSelectedRows(ctx, place, input_var);
} else {
PADDLE_ENFORCE(
false,
"SaveOp only support LoDTensor and SelectedRows, %s has wrong type",
iname);
}
}
void SaveLodTensor(const framework::ExecutionContext &ctx,
const platform::Place &place,
const framework::Variable *var) const {
auto filename = ctx.Attr<std::string>("file_path");
auto overwrite = ctx.Attr<bool>("overwrite");
if (FileExists(filename) && !overwrite) {
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
filename, overwrite);
}
MkDirRecursively(DirName(filename).c_str());
auto &tensor = var->Get<framework::LoDTensor>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ofstream fout(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
filename);
auto save_as_fp16 = ctx.Attr<bool>("save_as_fp16");
auto in_dtype = tensor.type();
auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) {
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
framework::LoDTensor out;
framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out);
// copy LoD info to the new tensor
out.set_lod(tensor.lod());
framework::SerializeToStream(fout, out, dev_ctx);
} else {
framework::SerializeToStream(fout, tensor, dev_ctx);
}
fout.close();
}
void SaveSelectedRows(const framework::ExecutionContext &ctx,
const platform::Place &place,
const framework::Variable *var) const {
framework::Variable *out_put_var = ctx.OutputVar(LOOKUP_TABLE_PATH);
PADDLE_ENFORCE(
out_put_var != nullptr,
"Can not find variable kLookupTablePath for SaveSelectedRows");
auto *lt_var = out_put_var->GetMutable<std::string>();
std::string filename = lt_var->data();
VLOG(4) << "SaveSelectedRows get File name: " << filename;
MkDirRecursively(DirName(filename).c_str());
auto &selectedRows = var->Get<framework::SelectedRows>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ofstream fout(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
filename);
framework::SerializeToStream(fout, selectedRows, dev_ctx);
fout.close();
}
};
} // namespace operators
} // namespace paddle
...@@ -40,7 +40,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase { ...@@ -40,7 +40,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
"tensor's rank."); "tensor's rank.");
} }
auto out_dims = GetOutputShape(axes, x_dims); auto out_dims = GetOutputShape(axes, x_dims, false);
ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("Out", out_dims);
if (x_dims[0] == out_dims[0]) { if (x_dims[0] == out_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X) // Only pass LoD when the first dimension of output and Input(X)
...@@ -50,7 +50,8 @@ class SqueezeOpInferShape : public framework::InferShapeBase { ...@@ -50,7 +50,8 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
} }
static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims, static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
const framework::DDim &in_dims) { const framework::DDim &in_dims,
bool is_runtime) {
size_t num_squeeze_dims = squeeze_dims.size(); size_t num_squeeze_dims = squeeze_dims.size();
int cnt_squeezed_dims = 0; int cnt_squeezed_dims = 0;
bool should_squeeze[9] = {false}; bool should_squeeze[9] = {false};
...@@ -71,9 +72,12 @@ class SqueezeOpInferShape : public framework::InferShapeBase { ...@@ -71,9 +72,12 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
// Check current index, the upper limit has been checked in line 36. // Check current index, the upper limit has been checked in line 36.
PADDLE_ENFORCE(current >= 0, PADDLE_ENFORCE(current >= 0,
"Invalid axis, the negative axis is out of range."); "Invalid axis, the negative axis is out of range.");
PADDLE_ENFORCE(in_dims[current] == 1,
"Invalid axis index, the axis that will be squeezed " if (is_runtime) {
"should be equal to 1."); PADDLE_ENFORCE(in_dims[current] == 1,
"Invalid axis index, the axis that will be squeezed "
"should be equal to 1.");
}
if (!(should_squeeze[current])) { if (!(should_squeeze[current])) {
++cnt_squeezed_dims; ++cnt_squeezed_dims;
...@@ -104,7 +108,7 @@ class SqueezeOp : public framework::OperatorBase { ...@@ -104,7 +108,7 @@ class SqueezeOp : public framework::OperatorBase {
const platform::Place &place) const override { const platform::Place &place) const override {
auto &axes = Attr<std::vector<int>>("axes"); auto &axes = Attr<std::vector<int>>("axes");
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims(); auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
auto out_dims = SqueezeOpInferShape::GetOutputShape(axes, x_dims); auto out_dims = SqueezeOpInferShape::GetOutputShape(axes, x_dims, true);
framework::AttributeMap attrs; framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(out_dims); attrs["shape"] = framework::vectorize2int(out_dims);
...@@ -224,7 +228,7 @@ class Squeeze2Op : public framework::OperatorBase { ...@@ -224,7 +228,7 @@ class Squeeze2Op : public framework::OperatorBase {
const platform::Place &place) const override { const platform::Place &place) const override {
auto &axes = Attr<std::vector<int>>("axes"); auto &axes = Attr<std::vector<int>>("axes");
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims(); auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
auto out_dims = Squeeze2OpInferShape::GetOutputShape(axes, x_dims); auto out_dims = Squeeze2OpInferShape::GetOutputShape(axes, x_dims, true);
framework::AttributeMap attrs; framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(out_dims); attrs["shape"] = framework::vectorize2int(out_dims);
......
...@@ -34,8 +34,11 @@ class TopkOp : public framework::OperatorWithKernel { ...@@ -34,8 +34,11 @@ class TopkOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_GE(k, 1, "k must >= 1"); PADDLE_ENFORCE_GE(k, 1, "k must >= 1");
PADDLE_ENFORCE_GE(input_dims.size(), 1, "input must have >= 1d shape"); PADDLE_ENFORCE_GE(input_dims.size(), 1, "input must have >= 1d shape");
PADDLE_ENFORCE_GE(input_dims[input_dims.size() - 1], k,
"input must have >= k columns"); if (ctx->IsRuntime()) {
PADDLE_ENFORCE_GE(input_dims[input_dims.size() - 1], k,
"input must have >= k columns");
}
framework::DDim dims = input_dims; framework::DDim dims = input_dims;
dims[dims.size() - 1] = k; dims[dims.size() - 1] = k;
......
...@@ -23,7 +23,9 @@ limitations under the License. */ ...@@ -23,7 +23,9 @@ limitations under the License. */
#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/cuda_helper.h"
#include "paddle/fluid/platform/dynload/cublas.h" #include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h" #include "paddle/fluid/platform/dynload/cudnn.h"
#if !defined(__APPLE__) && !defined(_WIN32)
#include "paddle/fluid/platform/dynload/nccl.h" #include "paddle/fluid/platform/dynload/nccl.h"
#endif
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/gpu_info.h"
#endif #endif
......
...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/gpu_info.h"
#include <algorithm> #include <algorithm>
#include <cstdlib> #include <cstdlib>
#include <string> #include <string>
...@@ -31,6 +30,8 @@ constexpr static float fraction_of_gpu_memory_to_use = 0.92f; ...@@ -31,6 +30,8 @@ constexpr static float fraction_of_gpu_memory_to_use = 0.92f;
constexpr static float fraction_of_gpu_memory_to_use = 0.5f; constexpr static float fraction_of_gpu_memory_to_use = 0.5f;
#endif #endif
constexpr static float fraction_reserve_gpu_memory = 0.05f;
DEFINE_double(fraction_of_gpu_memory_to_use, fraction_of_gpu_memory_to_use, DEFINE_double(fraction_of_gpu_memory_to_use, fraction_of_gpu_memory_to_use,
"Allocate a trunk of gpu memory that is this fraction of the " "Allocate a trunk of gpu memory that is this fraction of the "
"total gpu memory size. Future memory usage will be allocated " "total gpu memory size. Future memory usage will be allocated "
...@@ -38,6 +39,24 @@ DEFINE_double(fraction_of_gpu_memory_to_use, fraction_of_gpu_memory_to_use, ...@@ -38,6 +39,24 @@ DEFINE_double(fraction_of_gpu_memory_to_use, fraction_of_gpu_memory_to_use,
"additional trunks of the same size will be requested from gpu " "additional trunks of the same size will be requested from gpu "
"until the gpu has no memory left for another trunk."); "until the gpu has no memory left for another trunk.");
DEFINE_uint64(
initial_gpu_memory_in_mb, 0ul,
"Allocate a trunk of gpu memory whose byte size is specified by "
"the flag. Future memory usage will be allocated from the "
"truck. If the trunk doesn't have enough gpu memory, additional "
"trunks of the gpu memory will be requested from gpu with size "
"specified by FLAGS_reallocate_gpu_memory_in_mb until the gpu has "
"no memory left for the additional trunk. Note: if you set this "
"flag, the memory size set by "
"FLAGS_fraction_of_gpu_memory_to_use will be overrided by this "
"flag. If you don't set this flag, PaddlePaddle will use "
"FLAGS_fraction_of_gpu_memory_to_use to allocate gpu memory");
DEFINE_uint64(reallocate_gpu_memory_in_mb, 0ul,
"If this flag is set, Paddle will reallocate the gpu memory with "
"size specified by this flag. Else Paddle will reallocate by "
"FLAGS_fraction_of_gpu_memory_to_use");
DEFINE_bool( DEFINE_bool(
enable_cublas_tensor_op_math, false, enable_cublas_tensor_op_math, false,
"The enable_cublas_tensor_op_math indicate whether to use Tensor Core, " "The enable_cublas_tensor_op_math indicate whether to use Tensor Core, "
...@@ -180,13 +199,43 @@ void GpuMemoryUsage(size_t *available, size_t *total) { ...@@ -180,13 +199,43 @@ void GpuMemoryUsage(size_t *available, size_t *total) {
} }
size_t GpuMaxAllocSize() { size_t GpuMaxAllocSize() {
return std::max(GpuInitAllocSize(), GpuReallocSize());
}
size_t GpuInitAllocSize() {
if (FLAGS_initial_gpu_memory_in_mb > 0ul) {
// Initial memory will be allocated by FLAGS_initial_gpu_memory_in_mb
return static_cast<size_t>(FLAGS_initial_gpu_memory_in_mb << 20);
}
// FLAGS_initial_gpu_memory_in_mb is 0, initial memory will be allocated by
// fraction
size_t total = 0; size_t total = 0;
size_t available = 0; size_t available = 0;
GpuMemoryUsage(&available, &total); GpuMemoryUsage(&available, &total);
size_t reserving = static_cast<size_t>(fraction_reserve_gpu_memory * total);
// Reserve the rest for page tables, etc. return static_cast<size_t>((total - reserving) *
return static_cast<size_t>(total * FLAGS_fraction_of_gpu_memory_to_use); FLAGS_fraction_of_gpu_memory_to_use);
}
size_t GpuReallocSize() {
if (FLAGS_reallocate_gpu_memory_in_mb > 0ul) {
// Additional memory will be allocated by FLAGS_reallocate_gpu_memory_in_mb
return static_cast<size_t>(FLAGS_reallocate_gpu_memory_in_mb << 20);
}
// FLAGS_reallocate_gpu_memory_in_mb is 0, additional memory will be allocated
// by fraction
size_t total = 0;
size_t available = 0;
GpuMemoryUsage(&available, &total);
size_t reserving = static_cast<size_t>(fraction_reserve_gpu_memory * total);
return static_cast<size_t>((total - reserving) *
FLAGS_fraction_of_gpu_memory_to_use);
} }
size_t GpuMinChunkSize() { size_t GpuMinChunkSize() {
...@@ -201,16 +250,13 @@ size_t GpuMaxChunkSize() { ...@@ -201,16 +250,13 @@ size_t GpuMaxChunkSize() {
GpuMemoryUsage(&available, &total); GpuMemoryUsage(&available, &total);
VLOG(10) << "GPU Usage " << available / 1024 / 1024 << "M/" VLOG(10) << "GPU Usage " << available / 1024 / 1024 << "M/"
<< total / 1024 / 1024 << "M"; << total / 1024 / 1024 << "M";
size_t reserving = static_cast<size_t>(0.05 * total); size_t reserving = static_cast<size_t>(fraction_reserve_gpu_memory * total);
// If available less than minimum chunk size, no usable memory exists. // If available less than minimum chunk size, no usable memory exists.
available = available =
std::min(std::max(available, GpuMinChunkSize()) - GpuMinChunkSize(), std::min(std::max(available, GpuMinChunkSize()) - GpuMinChunkSize(),
total - reserving); total - reserving);
// Reserving the rest memory for page tables, etc. size_t allocating = GpuMaxAllocSize();
size_t allocating = static_cast<size_t>(FLAGS_fraction_of_gpu_memory_to_use *
(total - reserving));
PADDLE_ENFORCE_LE(allocating, available, PADDLE_ENFORCE_LE(allocating, available,
"Insufficient GPU memory to allocation."); "Insufficient GPU memory to allocation.");
......
...@@ -60,6 +60,12 @@ void GpuMemoryUsage(size_t *available, size_t *total); ...@@ -60,6 +60,12 @@ void GpuMemoryUsage(size_t *available, size_t *total);
//! Get the maximum allocation size of current GPU device. //! Get the maximum allocation size of current GPU device.
size_t GpuMaxAllocSize(); size_t GpuMaxAllocSize();
//! Get the initial allocation size of current GPU device.
size_t GpuInitAllocSize();
//! Get the re-allocation size of current GPU device.
size_t GpuReallocSize();
//! Get the minimum chunk size for GPU buddy allocator. //! Get the minimum chunk size for GPU buddy allocator.
size_t GpuMinChunkSize(); size_t GpuMinChunkSize();
......
...@@ -453,6 +453,7 @@ function assert_api_spec_approvals() { ...@@ -453,6 +453,7 @@ function assert_api_spec_approvals() {
echo "checking ${API_FILE} change, PR: ${GIT_PR_ID}, changes: ${API_CHANGE}" echo "checking ${API_FILE} change, PR: ${GIT_PR_ID}, changes: ${API_CHANGE}"
if [ ${API_CHANGE} ] && [ "${GIT_PR_ID}" != "" ]; then if [ ${API_CHANGE} ] && [ "${GIT_PR_ID}" != "" ]; then
# NOTE: per_page=10000 should be ok for all cases, a PR review > 10000 is not human readable. # NOTE: per_page=10000 should be ok for all cases, a PR review > 10000 is not human readable.
# approval_user_list: velconia 1979255,panyx0718 2887803,XiaoguangHu01 46782768,chengduoZH 30176695,Xreki 12538138,luotao1 6836917,sneaxiy 32832641,tensor-tang 21351065,jacquesqiao 3048612,typhoonzero 13348433,shanyi15 35982308.
if [ "$API_FILE" == "paddle/fluid/API.spec" ];then if [ "$API_FILE" == "paddle/fluid/API.spec" ];then
APPROVALS=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000 | \ APPROVALS=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000 | \
python ${PADDLE_ROOT}/tools/check_pr_approval.py 2 2887803 35982308 46782768 30176695` python ${PADDLE_ROOT}/tools/check_pr_approval.py 2 2887803 35982308 46782768 30176695`
...@@ -462,14 +463,14 @@ function assert_api_spec_approvals() { ...@@ -462,14 +463,14 @@ function assert_api_spec_approvals() {
fi fi
else else
APPROVALS=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000 | \ APPROVALS=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000 | \
python ${PADDLE_ROOT}/tools/check_pr_approval.py 1 2887803` python ${PADDLE_ROOT}/tools/check_pr_approval.py 1 2887803 1979255 21351065 3048612 13348433 46782768 30176695 12538138 6836917 32832641`
fi fi
echo "current pr ${GIT_PR_ID} got approvals: ${APPROVALS}" echo "current pr ${GIT_PR_ID} got approvals: ${APPROVALS}"
if [ "${APPROVALS}" == "FALSE" ]; then if [ "${APPROVALS}" == "FALSE" ]; then
if [ "$API_FILE" == "paddle/fluid/API.spec" ];then if [ "$API_FILE" == "paddle/fluid/API.spec" ];then
echo "You must have one RD (panyx0718 or chengduoZH or XiaoguangHu01) and one PM (shanyi15) approval for the api change! ${API_FILE}" echo "You must have one RD (panyx0718 or chengduoZH or XiaoguangHu01) and one PM (shanyi15) approval for the api change! ${API_FILE}"
else else
echo "You must have panyx0718 approval for the api change! ${API_FILE}" echo "You must have one RD (velconia,panyx0718,XiaoguangHu01,chengduoZH,Xreki,luotao1,sneaxiy,tensor-tang,jacquesqiao,typhoonzero) approval for the api change! ${API_FILE}"
fi fi
exit 1 exit 1
fi fi
...@@ -479,10 +480,10 @@ function assert_api_spec_approvals() { ...@@ -479,10 +480,10 @@ function assert_api_spec_approvals() {
HAS_CONST_CAST=`git diff -U0 upstream/$BRANCH |grep -o -m 1 "const_cast" || true` HAS_CONST_CAST=`git diff -U0 upstream/$BRANCH |grep -o -m 1 "const_cast" || true`
if [ ${HAS_CONST_CAST} ] && [ "${GIT_PR_ID}" != "" ]; then if [ ${HAS_CONST_CAST} ] && [ "${GIT_PR_ID}" != "" ]; then
APPROVALS=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000 | \ APPROVALS=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000 | \
python ${PADDLE_ROOT}/tools/check_pr_approval.py 1 2887803` python ${PADDLE_ROOT}/tools/check_pr_approval.py 1 2887803 1979255 21351065 3048612 13348433 46782768 30176695 12538138 6836917 32832641`
echo "current pr ${GIT_PR_ID} got approvals: ${APPROVALS}" echo "current pr ${GIT_PR_ID} got approvals: ${APPROVALS}"
if [ "${APPROVALS}" == "FALSE" ]; then if [ "${APPROVALS}" == "FALSE" ]; then
echo "You must have panyx0718 approval for the const_cast" echo "You must have one RD (velconia,panyx0718,XiaoguangHu01,chengduoZH,Xreki,luotao1,sneaxiy,tensor-tang,jacquesqiao,typhoonzero) approval for the api change! ${API_FILE}"
exit 1 exit 1
fi fi
fi fi
......
...@@ -41,6 +41,8 @@ int main(int argc, char** argv) { ...@@ -41,6 +41,8 @@ int main(int argc, char** argv) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
envs.push_back("fraction_of_gpu_memory_to_use"); envs.push_back("fraction_of_gpu_memory_to_use");
envs.push_back("initial_gpu_memory_in_mb");
envs.push_back("reallocate_gpu_memory_in_mb");
envs.push_back("allocator_strategy"); envs.push_back("allocator_strategy");
#elif __clang__ #elif __clang__
envs.push_back("use_mkldnn"); envs.push_back("use_mkldnn");
......
...@@ -163,7 +163,8 @@ def __bootstrap__(): ...@@ -163,7 +163,8 @@ def __bootstrap__():
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
read_env_flags += [ read_env_flags += [
'fraction_of_gpu_memory_to_use', 'cudnn_deterministic', 'fraction_of_gpu_memory_to_use', 'initial_gpu_memory_in_mb',
'reallocate_gpu_memory_in_mb', 'cudnn_deterministic',
'enable_cublas_tensor_op_math', 'conv_workspace_size_limit', 'enable_cublas_tensor_op_math', 'conv_workspace_size_limit',
'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus', 'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus',
'sync_nccl_allreduce', 'limit_of_tmp_allocation', 'sync_nccl_allreduce', 'limit_of_tmp_allocation',
......
...@@ -13,13 +13,4 @@ ...@@ -13,13 +13,4 @@
# limitations under the License. # limitations under the License.
from .core import * from .core import *
from .graph import * __all__ = ['Compressor', ]
from .prune import *
__all__ = [
'build_compressor',
'CompressPass',
'ImitationGraph',
'SensitivePruneStrategy',
'MagnitudePruner',
'RatioPruner',
]
...@@ -14,11 +14,9 @@ ...@@ -14,11 +14,9 @@
from . import config from . import config
from .config import * from .config import *
from . import compress_pass from . import compressor
from .compress_pass import * from .compressor import *
from . import strategy from . import strategy
from .strategy import * from .strategy import *
from . import pass_builder
from .pass_builder import *
__all__ = config.__all__ + compress_pass.__all__ + strategy.__all__ + pass_builder.__all__ __all__ = config.__all__ + compressor.__all__ + strategy.__all__
...@@ -17,7 +17,7 @@ import funcsigs ...@@ -17,7 +17,7 @@ import funcsigs
import yaml import yaml
from collections import OrderedDict from collections import OrderedDict
from ..prune import * from ..prune import *
from .compress_pass import * from ..quantization import *
from .strategy import * from .strategy import *
__all__ = ['ConfigFactory'] __all__ = ['ConfigFactory']
...@@ -29,15 +29,10 @@ class ConfigFactory(object): ...@@ -29,15 +29,10 @@ class ConfigFactory(object):
def __init__(self, config): def __init__(self, config):
"""Init a factory from configure file.""" """Init a factory from configure file."""
self.instances = {} self.instances = {}
self.compressor = {}
self.version = None self.version = None
self._parse_config(config) self._parse_config(config)
def get_compress_pass(self):
"""
Get compress pass from factory.
"""
return self.instance('compress_pass')
def instance(self, name): def instance(self, name):
""" """
Get instance from factory. Get instance from factory.
...@@ -59,8 +54,16 @@ class ConfigFactory(object): ...@@ -59,8 +54,16 @@ class ConfigFactory(object):
args = {} args = {}
for key in keys: for key in keys:
value = attrs[key] value = attrs[key]
if isinstance(value, str) and value.lower() == 'none':
value = None
if isinstance(value, str) and value in self.instances: if isinstance(value, str) and value in self.instances:
value = self.instances[value] value = self.instances[value]
if isinstance(value, list):
for i in range(len(value)):
if isinstance(value[i],
str) and value[i] in self.instances:
value[i] = self.instances[value[i]]
args[key] = value args[key] = value
self.instances[name] = class_(**args) self.instances[name] = class_(**args)
return self.instances.get(name) return self.instances.get(name)
...@@ -76,16 +79,23 @@ class ConfigFactory(object): ...@@ -76,16 +79,23 @@ class ConfigFactory(object):
assert self.version == int(key_values['version']) assert self.version == int(key_values['version'])
# parse pruners # parse pruners
if key == 'pruners' or key == 'strategies': if key == 'distillers' or key == 'pruners' or key == 'quantizers' or key == 'strategies':
instances = key_values[key] instances = key_values[key]
for name in instances: for name in instances:
self._new_instance(name, instances[name]) self._new_instance(name, instances[name])
if key == 'compress_pass': if key == 'compressor':
compress_pass = self._new_instance(key, key_values[key]) self.compressor['strategies'] = []
for name in key_values[key]['strategies']: self.compressor['epoch'] = key_values[key]['epoch']
strategy = self.instance(name) if 'init_model' in key_values[key]:
compress_pass.add_strategy(strategy) self.compressor['init_model'] = key_values[key][
'init_model']
self.compressor['checkpoint_path'] = key_values[key][
'checkpoint_path']
if 'strategies' in key_values[key]:
for name in key_values[key]['strategies']:
strategy = self.instance(name)
self.compressor['strategies'].append(strategy)
if key == 'include': if key == 'include':
for config_file in key_values[key]: for config_file in key_values[key]:
......
...@@ -20,7 +20,7 @@ class Strategy(object): ...@@ -20,7 +20,7 @@ class Strategy(object):
Base class for all strategies. Base class for all strategies.
""" """
def __init__(self, start_epoch=0, end_epoch=10): def __init__(self, start_epoch=0, end_epoch=0):
""" """
Args: Args:
start_epoch: The first epoch to apply the strategy. start_epoch: The first epoch to apply the strategy.
...@@ -29,7 +29,7 @@ class Strategy(object): ...@@ -29,7 +29,7 @@ class Strategy(object):
self.start_epoch = start_epoch self.start_epoch = start_epoch
self.end_epoch = end_epoch self.end_epoch = end_epoch
def on_compress_begin(self, context): def on_compression_begin(self, context):
pass pass
def on_epoch_begin(self, context): def on_epoch_begin(self, context):
...@@ -44,5 +44,5 @@ class Strategy(object): ...@@ -44,5 +44,5 @@ class Strategy(object):
def on_batch_end(self, context): def on_batch_end(self, context):
pass pass
def on_compress_end(self, context): def on_compression_end(self, context):
pass pass
version: 1.0
pruners:
pruner_1:
class: 'RatioPruner'
ratios:
'conv1_1.w': 0.3
'conv1_2.w': 0.4
'*': 0.9
group_dims:
'*': [1, 2, 3]
criterions:
'*': 'l1-norm'
strategies:
strategy_1:
class: 'SensitivePruneStrategy'
pruner: 'pruner_1'
start_epoch: 0
end_epoch: 10
delta_rate: 0.20
acc_loss_threshold: 0.2
sensitivities:
'conv1_1.w': 0.4
compress_pass:
class: 'CompressPass'
epoch: 100
strategies:
- strategy_1
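# Illustrative sketch (not part of this commit): with the updated ConfigFactory
# above, the top-level section is named 'compressor'; paths and the strategy
# name below are placeholders:
#
# compressor:
#     epoch: 100
#     init_model: './init_model'          # optional
#     checkpoint_path: './checkpoints'
#     strategies:
#         - strategy_1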