Commit eaeff90e authored by phlrain

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_some_yaml_config

cc_library(processgroup SRCS ProcessGroup.cc DEPS phi phi_api eager_api)
cc_library(eager_reducer SRCS reducer.cc DEPS eager_api processgroup phi phi_api)
cc_library(eager_reducer SRCS reducer.cc DEPS eager_api processgroup phi phi_api string_helper)
if (WITH_DISTRIBUTE)
cc_library(processgroup_gloo SRCS ProcessGroupGloo.cc DEPS phi phi_api eager_api gloo_wrapper)
......
......@@ -171,10 +171,10 @@ ProcessGroupGloo::GlooTask::GlooTask(int rank,
"Only CPU place is supported for ProcessGroupGloo."));
}
ProcessGroupGloo::ProcessGroupGloo(const std::shared_ptr<GlooStore>& store,
int rank, int world_size,
const std::shared_ptr<GlooOptions> options)
: ProcessGroup(rank, world_size), _tag(0), _store(store) {
ProcessGroupGloo::ProcessGroupGloo(
const std::shared_ptr<paddle::distributed::Store>& store, int rank,
int world_size, const std::shared_ptr<GlooOptions> options)
: ProcessGroup(rank, world_size), _tag(0), _store(new GlooStore(store)) {
_context = std::make_shared<gloo::rendezvous::Context>(rank, world_size);
auto prefix_store =
::gloo::rendezvous::PrefixStore(std::to_string(0), *_store);
......
......@@ -52,8 +52,7 @@ class ProcessGroupGloo : public ProcessGroup {
class GlooStore : public ::gloo::rendezvous::Store {
public:
explicit GlooStore(
const std::shared_ptr<paddle::distributed::TCPStore>& store)
explicit GlooStore(const std::shared_ptr<paddle::distributed::Store>& store)
: _store(store) {}
~GlooStore() = default;
......@@ -87,7 +86,7 @@ class ProcessGroupGloo : public ProcessGroup {
}
protected:
std::shared_ptr<paddle::distributed::TCPStore> _store;
std::shared_ptr<paddle::distributed::Store> _store;
};
class GlooOptions {
......@@ -100,9 +99,9 @@ class ProcessGroupGloo : public ProcessGroup {
std::shared_ptr<::gloo::transport::Device> device;
};
explicit ProcessGroupGloo(const std::shared_ptr<GlooStore>& store, int rank,
int world_size,
std::shared_ptr<GlooOptions> options);
explicit ProcessGroupGloo(
const std::shared_ptr<paddle::distributed::Store>& store, int rank,
int world_size, std::shared_ptr<GlooOptions> options);
~ProcessGroupGloo() = default;
......@@ -145,7 +144,7 @@ class ProcessGroupGloo : public ProcessGroup {
protected:
uint32_t _tag;
std::shared_ptr<gloo::rendezvous::Context> _context;
std::shared_ptr<GlooStore> _store;
std::shared_ptr<::gloo::rendezvous::Store> _store;
};
} // namespace distributed
......
......@@ -17,6 +17,20 @@
namespace paddle {
namespace distributed {
static Backend TransToBackend(platform::Place place) {
static const std::map<phi::AllocationType, Backend> type_backend = {
{phi::AllocationType::GPU, Backend::GPU},
{phi::AllocationType::CPU, Backend::CPU},
};
phi::AllocationType type = place.GetType();
auto it = type_backend.find(type);
PADDLE_ENFORCE_EQ(it != type_backend.end(), true,
platform::errors::InvalidArgument(
"Place type (%s) is not supported. ", place));
return it->second;
}
std::vector<std::vector<size_t>> Eager_AssignGroupBySize(
const std::vector<Tensor> tensors,
const std::vector<bool> &is_sparse_gradient,
......@@ -297,10 +311,18 @@ EagerReducer::EagerReducer(
std::dynamic_pointer_cast<egr::GradNodeAccumulation>(grad_node);
accumulation_grad_node->RegisterReduceHook(
std::make_shared<egr::CppTensorVoidHook>(reduce_hook));
gradnode_index_map_[grad_node.get()] = global_var_index;
}
vars_marked_ready_.resize(tensors_.size(), false);
local_used_vars_.resize(tensors_.size(), 0);
if (find_unused_vars_each_step_) {
global_used_vars_ = paddle::experimental::empty(
ScalarArray({static_cast<int32_t>(tensors_.size())}), DataType::INT32,
TransToBackend(inner_place_));
}
}
std::shared_ptr<egr::GradNodeBase> EagerReducer::GetGradNodeFromTensor(
......@@ -341,21 +363,10 @@ void EagerReducer::InitializeGroups(
} else {
// process the dense gradient.
InitializeDenseGroups(tensor_indices_, &group);
experimental::Backend backend;
switch (inner_place_.GetType()) {
case phi::AllocationType::GPU:
backend = experimental::Backend::GPU;
break;
case phi::AllocationType::CPU:
backend = experimental::Backend::CPU;
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Place type (%s) is not supported. ", inner_place_));
break;
}
// experimental::Backend backend = TransToBackend(inner_place_);
group.dense_contents_ = paddle::experimental::empty(
ScalarArray({group.all_length_}), group.dtype_, backend);
ScalarArray({group.all_length_}), group.dtype_,
TransToBackend(inner_place_));
}
// map tensors to this group by VariableLocator
......@@ -418,6 +429,53 @@ void EagerReducer::InitializeDenseGroups(
p_group->all_length_ = all_length;
}
void EagerReducer::TraverseBackwardGraph(const std::vector<Tensor> &outputs) {
std::queue<egr::GradNodeBase *> queue;
std::set<egr::GradNodeBase *> visited;
for (const auto &output : outputs) {
auto *auto_grad_meta =
static_cast<egr::AutogradMeta *>(output.get_autograd_meta());
if (!auto_grad_meta) continue;
auto shared_grad_node = auto_grad_meta->GetMutableGradNode();
if (shared_grad_node == nullptr || shared_grad_node.get() == nullptr ||
auto_grad_meta->StopGradient()) {
continue;
}
egr::GradNodeBase *grad_node = shared_grad_node.get();
queue.emplace(grad_node);
}
while (!queue.empty()) {
egr::GradNodeBase *node = queue.front();
queue.pop();
const std::vector<std::vector<egr::Edge>> &edges = node->GetEdges();
for (size_t i = 0; i < edges.size(); i++) {
for (size_t j = 0; j < edges[i].size(); j++) {
const egr::Edge &edge = edges[i][j];
auto next_node_shared = edge.GetMutableGradNode();
if (!next_node_shared || !next_node_shared.get()) {
continue;
}
auto *next_node = next_node_shared.get();
const bool was_inserted = visited.insert(next_node).second;
if (was_inserted) {
queue.emplace(next_node);
}
}
}
}
for (const auto &it : gradnode_index_map_) {
if (visited.count(it.first) == 0) {
unused_vars_.push_back(it.second);
VLOG(3) << "[Rank " << process_group_->GetRank() << "]: "
<< "Tensor " << tensors_[it.second].name() << " at index "
<< it.second << " is marked as unused.";
}
}
}
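The traversal above is a breadth-first search over the autograd graph. A minimal Python sketch of the same idea (illustrative only; the next_nodes adjacency accessor is a hypothetical stand-in for GetEdges()): start from the grad nodes of the forward outputs and report every parameter whose accumulation node is never reached.

from collections import deque

def find_unused(output_grad_nodes, param_node_to_index):
    # Return indices of parameters whose grad node is unreachable from the outputs.
    visited, queue = set(), deque(output_grad_nodes)
    while queue:
        node = queue.popleft()
        for nxt in node.next_nodes:  # hypothetical stand-in for GetEdges()
            if nxt is not None and nxt not in visited:
                visited.add(nxt)
                queue.append(nxt)
    return [idx for node, idx in param_node_to_index.items() if node not in visited]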
void EagerReducer::PrepareForBackward(const std::vector<Tensor> &outputs) {
VLOG(3) << "after forward, then reset count for backward.";
grad_need_hooks_ = true;
......@@ -429,6 +487,51 @@ void EagerReducer::PrepareForBackward(const std::vector<Tensor> &outputs) {
// reinitialize vars_marked_ready_ for next iteration
vars_marked_ready_.clear();
vars_marked_ready_.resize(tensors_.size(), false);
PADDLE_ENFORCE_EQ(
groups_need_finalize_, false,
platform::errors::PreconditionNotMet(
"A serious error has occurred here. Please "
"set find_unused_parameters=True to traverse backward graph "
"in each step to prepare reduce in advance. If you have "
"set, There may be several reasons for this error: "
"1) Please note that all forward outputs derived from the module "
"parameters must participate in the calculation of losses and "
"subsequent gradient calculations. If not, the wrapper will hang, "
"waiting for autograd to generate gradients for these parameters. "
"you can use detach or stop_gradient to make the unused parameters "
"detached from the autograd graph. "
"2) Used multiple forwards and one backward. You may be able to wrap "
"multiple forwards in a model."));
// The first var to arrive triggers marking of the unused parameters
has_marked_unused_vars_ = false;
if (find_unused_vars_once_ || find_unused_vars_each_step_) {
unused_vars_.clear();
TraverseBackwardGraph(outputs);
// only check once in first step
find_unused_vars_once_ = false;
}
if (find_unused_vars_each_step_ && unused_vars_.empty()) {
LOG_FIRST_N(WARNING, 1)
<< "All parameters are involved in the backward pass. "
"It is recommended to set find_unused_parameters to False "
"to improve performance. However, if unused parameters "
"appear in subsequent iterative training, then an error "
"will occur. Please make it clear that in the subsequent "
"training, there will be no parameters that are not used "
"in the backward pass, and then set find_unused_parameters";
}
if (unused_vars_.size() == tensors_.size()) {
LOG_FIRST_N(WARNING, 1)
<< "There is no parameter in the device involved "
"in the backward calculation. If there are "
"parameters on other devices involved in the "
"backward, then a serious error will occur here.";
}
}
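For context, a minimal user-level sketch of how this path is exercised, assuming the public dygraph API wires find_unused_parameters through to find_unused_vars_each_step_:

import paddle
import paddle.distributed as dist

dist.init_parallel_env()                # one process per device
model = paddle.nn.Linear(10, 10)
dp_model = paddle.DataParallel(model, find_unused_parameters=True)

x = paddle.randn([4, 10])
loss = dp_model(x).mean()
loss.backward()                         # reducer hooks fire and the fused allreduce is scheduled here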
void EagerReducer::AddDistHook(size_t var_index) {
......@@ -446,36 +549,104 @@ void EagerReducer::AddDistHook(size_t var_index) {
auto &tensor = tensors_[var_index];
const auto &grad_node = GetGradNodeFromTensor(&tensor);
VLOG(3) << "Var[" << var_index << "] [" << (*grad_node).name()
<< "] arrived and triggered disthook";
VLOG(3) << "Tensor[" << var_index << "] [" << tensors_[var_index].name()
<< "@Grad] arrived and triggered disthook";
local_used_vars_[var_index] = 1;
if (!has_marked_unused_vars_) {
has_marked_unused_vars_ = true;
for (const auto unused_index : unused_vars_) {
MarkVarReady(unused_index, false);
}
}
MarkVarReady(var_index, true);
}
void EagerReducer::MarkVarReady(const size_t var_index,
const bool is_used_var) {
VLOG(3) << "Tensor[" << var_index << "][" << tensors_[var_index].name()
<< "] is marked ready.";
// It is an error if the var has already been marked ready.
if (vars_marked_ready_[var_index]) {
auto error_info = string::Sprintf(
"Error happened, when parameter[%d][%s] has been ready before. "
"Please set find_unused_parameters=True to traverse backward graph "
"in each step to prepare reduce in advance. If you have set, "
"there may be several reasons for this error: "
"1) In multiple reentrant backward phase, some parameters are reused."
"2) Using model parameters outside of forward function. Please "
"make sure that model parameters are not shared in concurrent "
"forward-backward passes.",
var_index, tensors_[var_index].name());
PADDLE_ENFORCE_EQ(has_marked_unused_vars_, false,
platform::errors::PreconditionNotMet(error_info));
error_info +=
"3) Unused parameters retrieval is incorrect. "
"The return value of forward will be used to retrieve"
" the unused parameters of the entire model. These "
"gradients of unused parameters will not be synchronized "
"between multiple cards. However, if the unused "
"parameters participate in the backward calculation "
"again at a later time (e.g. after the forward function, "
"the loss calculation uses the unused "
"paramters of the forward and trigger backward), "
"its gradient will be wrong.";
PADDLE_ENFORCE_EQ(has_marked_unused_vars_, true,
platform::errors::PreconditionNotMet(error_info));
} else {
vars_marked_ready_[var_index] = true;
}
groups_need_finalize_ = true;
const auto &var_locator = variable_locators_[var_index];
const auto group_index = var_locator.group_index;
const auto inside_group_index = var_locator.inside_group_index;
auto &group = groups_[group_index];
auto &group_tensor = group.dense_tensors_[inside_group_index];
auto *autograd_meta = tensors_[var_index].get_autograd_meta();
auto &grad_tensor = static_cast<egr::AutogradMeta *>(autograd_meta)->Grad();
group_tensor
.ShareDataWith(
*(std::dynamic_pointer_cast<phi::DenseTensor>(grad_tensor.impl())))
.Resize({grad_tensor.numel()});
vars_marked_ready_[var_index] = true;
const auto length = group.length_[inside_group_index];
if (is_used_var) {
auto *autograd_meta = tensors_[var_index].get_autograd_meta();
auto &grad_tensor = static_cast<egr::AutogradMeta *>(autograd_meta)->Grad();
group_tensor
.ShareDataWith(
*(std::dynamic_pointer_cast<phi::DenseTensor>(grad_tensor.impl())))
.Resize({grad_tensor.numel()});
} else {
// TODO(shenliang03): maybe save the memory by avoiding tensor construction
if (!group_tensor.initialized()) {
group_tensor.Resize({static_cast<int64_t>(length)});
group_tensor.mutable_data(inner_place_, group.dtype_);
}
if (HasGrad(var_index)) {
VLOG(3) << "Tensor[" << tensors_[var_index].name() << "] has grad";
auto grad_tensor = egr::EagerUtils::mutable_grad(tensors_[var_index]);
group_tensor
.ShareDataWith(*(
std::dynamic_pointer_cast<phi::DenseTensor>(grad_tensor->impl())))
.Resize({length});
} else {
VLOG(3) << "Tensor[" << tensors_[var_index].name()
<< "] doesn't have grad";
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(inner_place_);
group_tensor.Resize({static_cast<int64_t>(length)});
phi::funcs::set_constant(*dev_ctx, &group_tensor, 0.0);
}
}
if (--group.pending_ == 0) {
// can start allreduce
MarkGroupReady(group_index);
}
if (next_group_ == groups_.size()) {
FinalizeBackward();
}
}
void EagerReducer::MarkGroupReady(size_t group_index) {
......@@ -501,6 +672,92 @@ void EagerReducer::MarkGroupReady(size_t group_index) {
}
}
bool EagerReducer::HasGrad(size_t var_index) {
auto grad = egr::EagerUtils::mutable_grad(tensors_[var_index]);
if (grad && grad->is_initialized()) {
return true;
} else {
return false;
}
}
void EagerReducer::ProcessUnusedDenseVars() {
// The calculation stream must be used here to
// avoid conflicts with communication.
VLOG(3) << "Local used vars : "
<< string::join_strings(local_used_vars_, ',');
const auto *dev_ctx =
platform::DeviceContextPool::Instance().Get(inner_place_);
auto *global_used_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(global_used_vars_.impl())
.get();
framework::TensorFromVector<int32_t>(local_used_vars_, *dev_ctx,
global_used_tensor);
distributed::AllreduceOptions opts;
opts.reduce_op = ReduceOp::SUM;
std::vector<Tensor> reduce_tensors = {global_used_vars_};
process_group_->AllReduce(reduce_tensors, opts)->Synchronize();
framework::TensorToVector<int>(*global_used_tensor, *dev_ctx,
&local_used_vars_);
dev_ctx->Wait();
// Sync the compute stream to get the global used-var message;
// this may affect performance.
VLOG(3) << "Global used vars : "
<< string::join_strings(local_used_vars_, ',');
for (const auto var_index : unused_vars_) {
const bool global_unused = (local_used_vars_[var_index] == 0);
// global used but local unused, set grad
VLOG(3) << "[Rank " << process_group_->GetRank() << "]: "
<< "Var [" << var_index << "] [" << tensors_[var_index].name()
<< "] global_unused: " << global_unused
<< " has grad: " << HasGrad(var_index);
if (!global_unused) {
VLOG(3) << "Set Tensor[" << var_index << "]'s Grad for [Rank "
<< process_group_->GetRank() << "]";
const auto &var_locator = variable_locators_[var_index];
const auto group_index = var_locator.group_index;
const auto &group = groups_[group_index];
const auto inside_group_index = var_locator.inside_group_index;
auto &src_tensor = group.dense_tensors_[inside_group_index];
Tensor grad_value(std::make_shared<phi::DenseTensor>(src_tensor));
auto dest_var_base = tensors_[var_index];
auto grad_tensor = egr::EagerUtils::mutable_grad(dest_var_base);
grad_tensor->copy_(grad_value, inner_place_, true);
grad_tensor->reshape(dest_var_base.shape());
}
}
}
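A worked example of the bookkeeping above, in plain Python with no Paddle APIs: each rank allreduces its local 0/1 usage vector with SUM, and a parameter counts as globally unused only if its summed entry stays 0.

local_used = {0: [1, 0, 1],   # rank 0: parameters 0 and 2 produced grads
              1: [1, 0, 0]}   # rank 1: only parameter 0 produced a grad
global_used = [sum(col) for col in zip(*local_used.values())]       # [2, 0, 1]
globally_unused = [i for i, c in enumerate(global_used) if c == 0]  # [1]
# Parameter 2 is locally unused on rank 1 but globally used, so rank 1 still
# fills its gradient from the fused buffer; parameter 1 is skipped everywhere.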
void EagerReducer::FinalizeBackward() {
groups_need_finalize_ = false;
grad_need_hooks_ = false;
for (auto &group : groups_) {
group.task->Synchronize();
}
for (auto &group : groups_) {
group.SplitTensors(inner_place_);
}
if (find_unused_vars_each_step_) {
ProcessUnusedDenseVars();
local_used_vars_.clear();
local_used_vars_.resize(tensors_.size(), 0);
VLOG(3) << "ProcessUnusedDenseVars is finished.";
}
VLOG(3) << "In the batch, Reducer is finished.";
}
void EagerReducer::FusedAllReduceSchedule(EagerGroup *group,
const int curr_group_index) {
// The overall timeline: concat > div_nranks > allreduce > split
......@@ -513,24 +770,14 @@ void EagerReducer::FusedAllReduceSchedule(EagerGroup *group,
group->ConcatTensors(inner_place_);
// div nranks
double scaling = 1.0 / nranks_;
paddle::experimental::scale_(group->dense_contents_, scaling, 0.0, false);
paddle::experimental::scale_(group->dense_contents_, 1.0 / nranks_, 0.0,
false);
// all_reduce
std::vector<Tensor> reduce_tensors = {group->dense_contents_};
tasks_.push_back(process_group_->AllReduce(reduce_tensors, opts));
group->task = process_group_->AllReduce(reduce_tensors, opts);
if (tasks_.size() == groups_.size()) {
for (size_t index = 0; index < tasks_.size(); index++) {
auto &task = tasks_.back();
task->Synchronize();
tasks_.pop_back();
}
for (size_t index = 0; index < groups_.size(); index++) {
auto &group = groups_[index];
group.SplitTensors(inner_place_);
}
}
// split in FinalizeBackward()
}
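The per-group timeline (concat, scale by 1/nranks, asynchronous allreduce, split in FinalizeBackward) amounts to averaging each gradient across ranks; a purely illustrative NumPy sketch of that arithmetic:

import numpy as np

def fused_allreduce(per_rank_grads):
    # per_rank_grads: list over ranks; each entry is a list of ndarrays with identical shapes.
    nranks = len(per_rank_grads)
    flats = [np.concatenate([g.ravel() for g in grads]) for grads in per_rank_grads]
    fused = sum(flats) / nranks            # scale + allreduce(SUM) == cross-rank average
    offsets = np.cumsum([g.size for g in per_rank_grads[0]])[:-1]
    return [part.reshape(ref.shape)
            for part, ref in zip(np.split(fused, offsets), per_rank_grads[0])]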
std::ostream &operator<<(std::ostream &out, const EagerGroup &group) {
......
......@@ -28,6 +28,8 @@
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/api/lib/ext_compat_utils.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/utils/string/string_helper.h"
namespace paddle {
namespace distributed {
......@@ -35,6 +37,7 @@ using Tensor = paddle::experimental::Tensor;
using Scalar = paddle::experimental::ScalarBase<paddle::experimental::Tensor>;
using ScalarArray =
paddle::experimental::ScalarArrayBase<paddle::experimental::Tensor>;
using Backend = paddle::experimental::Backend;
std::vector<std::vector<size_t>> Eager_AssignGroupBySize(
const std::vector<Tensor>, const std::vector<bool> &is_sparse_gradient,
......@@ -61,6 +64,9 @@ class EagerGroup {
// external message of group
phi::DataType dtype_;
// help to sync
std::shared_ptr<ProcessGroup::Task> task;
// context is used to select the stream for concat
void ConcatTensors(const platform::Place &);
......@@ -98,6 +104,10 @@ class EagerReducer {
void MarkVarReady(const size_t var_index, const bool is_used_var);
void MarkGroupReady(const size_t group_index);
void FusedAllReduceSchedule(EagerGroup *group, const int curr_group_index);
void FinalizeBackward();
void TraverseBackwardGraph(const std::vector<Tensor> &outputs);
void ProcessUnusedDenseVars();
bool HasGrad(size_t var_index);
private:
std::vector<Tensor> tensors_;
......@@ -105,7 +115,6 @@ class EagerReducer {
std::vector<bool> is_sparse_gradient_;
std::shared_ptr<distributed::ProcessGroup> process_group_;
std::vector<size_t> group_size_limits_;
bool find_unused_vars_each_step_;
std::vector<EagerGroup> groups_;
std::vector<TensorLocator> variable_locators_;
......@@ -113,12 +122,20 @@ class EagerReducer {
platform::Place inner_place_;
size_t next_group_ = 0;
int64_t nranks_ = -1;
std::vector<std::shared_ptr<paddle::distributed::ProcessGroup::Task>> tasks_;
bool grad_need_hooks_{false};
std::vector<bool> vars_marked_ready_;
std::vector<int> local_used_vars_;
std::vector<int32_t> local_used_vars_;
// Following variables are to help unused vars
std::vector<size_t> unused_vars_;
std::map<egr::GradNodeBase *, size_t> gradnode_index_map_;
bool has_marked_unused_vars_{false};
bool find_unused_vars_each_step_{false};
bool find_unused_vars_once_{true};
bool groups_need_finalize_{false};
Tensor global_used_vars_;
};
} // namespace distributed
......
......@@ -86,9 +86,9 @@ paddle::experimental::Tensor scale(const paddle::experimental::Tensor& x,
scale_node->SetTensorWrappers_X({x});
// Set Grad out rank as same as fwd input and set stop gradient to bwd
scale_node->SetGradOutMeta(p_autograd_in, /*slot id*/ 0);
scale_node->SetGradOutMeta(x, /*slot id*/ 0);
// Set Grad out rank as same as fwd input and set stop gradient to bwd
scale_node->SetGradInMeta(p_autograd_out, /*slot id*/ 0);
scale_node->SetGradInMeta(out, /*slot id*/ 0);
// Set History for output set current Grad Node for
EagerUtils::SetHistory(p_autograd_out, scale_node);
......
......@@ -30,7 +30,8 @@ namespace egr_utils_api {
bool IsLeafTensor(const paddle::experimental::Tensor& target) {
std::shared_ptr<GradNodeBase> grad_node = EagerUtils::grad_node(target);
if (std::dynamic_pointer_cast<GradNodeAccumulation>(grad_node)) {
if (!grad_node ||
std::dynamic_pointer_cast<GradNodeAccumulation>(grad_node)) {
return true;
}
......
......@@ -27,6 +27,7 @@ add_custom_target(eager_final_state_codegen
set(tmp_python_c_output_path "${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/tmp_eager_final_state_op_function_impl.h")
set(python_c_output_path "${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/eager_final_state_op_function_impl.h")
add_custom_target(eager_final_state_python_c_codegen
COMMAND "${PYTHON_EXECUTABLE}" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py"
"--api_yaml_path=${api_yaml_path}"
......
......@@ -657,6 +657,7 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
else:
# Rearrange output order accordingly
returns_str += f"returns[{fwd_position}] = grad_api_returns[{grad_api_position}];\n"
returns_str += f"if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n"
returns_str += f"return returns;\n"
grad_node_name = GetGradNodeName(fwd_api_name)
......@@ -793,7 +794,7 @@ def GenerateNodeCreationCodes(
set_edges_list = []
for name, (_, pos) in forward_inputs_position_map.items():
input_autograd_meta_name = GetAutoGradMetaName(name)
set_grad_out_meta = f" grad_node->SetGradOutMeta({input_autograd_meta_name}, {pos});"
set_grad_out_meta = f" grad_node->SetGradOutMeta({name}, {pos});"
set_edges = f" grad_node->AddEdges({input_autograd_meta_name}, {pos});"
set_grad_out_meta_list.append(set_grad_out_meta)
set_edges_list.append(set_edges)
......@@ -810,17 +811,18 @@ def GenerateNodeCreationCodes(
output_autograd_meta_name = GetAutoGradMetaName(name)
set_out_rank = f" egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});"
set_history = f" egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);"
set_grad_in_meta = f" grad_node->SetGradInMeta({output_autograd_meta_name}, {pos});"
if num_outputs == 1:
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result);"
set_grad_in_meta = f" grad_node->SetGradInMeta(api_result, {pos});"
else:
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result[{pos}]);"
set_grad_in_meta = f" grad_node->SetGradInMeta(api_result[{pos}], {pos});"
set_out_rank_list.append(set_out_rank)
set_history_list.append(set_history)
set_grad_in_meta_list.append(set_grad_in_meta)
if num_outputs == 1:
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result);"
else:
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(std::get<{pos}>(api_result));"
set_retain_grad_list.append(set_retain_grad)
set_out_rank_str = "\n".join(set_out_rank_list)
set_history_str = "\n".join(set_history_list)
set_grad_in_meta_str = "\n".join(set_grad_in_meta_list)
......
......@@ -517,11 +517,11 @@ std::vector<paddle::experimental::Tensor> RunBackward(
}
// TODO(jiabin): Should we erase it or find a more efficient way.
node_input_buffers_dict.erase(node);
// Prepare GradTensorHolder for next node
const std::vector<std::vector<Edge>>& edges = node->GetEdges();
PADDLE_ENFORCE(edges.size() == grad_output_tensors.size() || edges.empty(),
paddle::platform::errors::Fatal(
"Number of edges should be either empty ( for leaf node "
......@@ -532,6 +532,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
for (size_t i = 0; i < edges.size(); i++) {
for (size_t j = 0; j < edges[i].size(); j++) {
const Edge& edge = edges[i][j];
auto edge_rank = edge.GetEdgeRankInfo();
// Since we make edge has as same rank as bwd outputs, we indexing them
// with
......@@ -545,6 +546,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
grad_output_tensors[i].empty()) {
continue;
}
PADDLE_ENFORCE_LT(
j, grad_output_tensors[i].size(),
paddle::platform::errors::Fatal(
......
......@@ -15,10 +15,16 @@
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
......@@ -33,7 +39,6 @@ GradNodeBase::GradNodeBase(size_t bwd_in_slot_num, size_t bwd_out_slot_num) {
VLOG(6) << "Construct GradNodeBase";
bwd_in_meta_.resize(bwd_in_slot_num);
bwd_out_meta_.resize(bwd_out_slot_num);
// adj_edges has the same num as backward outputs
adj_edges_.resize(bwd_out_slot_num);
}
......@@ -44,24 +49,20 @@ void GradNodeBase::AddEdges(std::vector<AutogradMeta*>* metas, size_t slot_id) {
"Given slot id is out of range of adj_edges outter size, "
"adj_edges is designed to has the same size of grad "
"inputs's slot num."));
for (const auto& meta : *metas) {
for (size_t i = 0; i < metas->size(); i++) {
const auto& meta = (*metas)[i];
// adj_edges has the same rank as the fwd inputs, and records each input's
// output rank from its pre-ops
if (meta && !meta->StopGradient()) {
auto node = meta->GetMutableGradNode();
if (node && node.get()) {
VLOG(6) << "Add Edges for slot: " << slot_id
<< " which is: " << meta->GetMutableGradNode()->name();
adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
meta->OutRankInfo());
} else {
if (!node || !node.get()) {
meta->SetGradNode(std::make_shared<egr::GradNodeAccumulation>(meta));
VLOG(6) << "Add Edges for slot: " << slot_id
<< " which is: " << meta->GetMutableGradNode()->name();
adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
meta->OutRankInfo());
}
adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
meta->OutRankInfo());
}
}
}
......@@ -73,130 +74,205 @@ void GradNodeBase::AddEdges(AutogradMeta* meta, size_t slot_id) {
"Given slot id is out of range of adj_edges outter size, "
"adj_edges is designed to has the same size of grad "
"inputs's slot num."));
if (meta && !meta->StopGradient()) {
auto node = meta->GetMutableGradNode();
if (node && node.get()) {
VLOG(6) << "Add Edges for slot: " << slot_id << ", the Edge is from "
<< this->name() << " to " << meta->GetMutableGradNode()->name();
adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
meta->OutRankInfo());
} else {
if (!node || !node.get()) {
meta->SetGradNode(std::make_shared<egr::GradNodeAccumulation>(meta));
VLOG(6) << "Add Edges for slot: " << slot_id << ", the Edge is from "
<< this->name() << " to " << meta->GetMutableGradNode()->name();
adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
meta->OutRankInfo());
}
VLOG(6) << "Add Edges for slot: " << slot_id << ", the Edge is from "
<< this->name() << " to " << meta->GetMutableGradNode()->name();
adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
meta->OutRankInfo());
}
}
const std::vector<GradSlotMeta>& GradNodeBase::InputMeta() const {
const std::vector<std::vector<GradSlotMeta>>& GradNodeBase::InputMeta() const {
return bwd_in_meta_;
}
const std::vector<GradSlotMeta>& GradNodeBase::OutputMeta() const {
const std::vector<std::vector<GradSlotMeta>>& GradNodeBase::OutputMeta() const {
return bwd_out_meta_;
}
void GradNodeBase::SetGradInMeta(std::vector<AutogradMeta*>* fwd_out,
void GradNodeBase::SetGradInMeta(const paddle::experimental::Tensor& fwd_out,
size_t slot_rank) {
size_t slot_size = fwd_out->size();
auto* fwd_out_meta = egr::EagerUtils::nullable_autograd_meta(fwd_out);
PADDLE_ENFORCE_LE(
slot_rank, (bwd_in_meta_.size() - 1),
paddle::platform::errors::InvalidArgument(
"Slot Rank should less equal than bwd_in_meta_ size, since "
"bwd_in_meta_ is designed to hold as same num as backward "
"inputs."));
auto& meta = bwd_in_meta_.at(slot_rank);
PADDLE_ENFORCE_EQ(meta.IsInitialized(), false,
paddle::platform::errors::PreconditionNotMet(
"Bwd_in_meta should only be init once, addition "
"initialization for it is forbidden. If you got this "
"error, it indicates bugs in framework."));
// Init stop gradient vector before use to avoid push back
meta.Init(slot_size);
for (size_t i = 0; i < slot_size; i++) {
PADDLE_ENFORCE_NOT_NULL((*fwd_out)[i],
paddle::platform::errors::PreconditionNotMet(
"Bwd_in_meta should only be called while "
"autograd_meta is not null. If you got this "
"error, it indicates bugs in framework."));
if ((*fwd_out)[i]->StopGradient()) {
// Set Stop Gradient only when its true or non-initialized autograd_meta,
// since all default value is false.
meta.SetStopGradient(i, (*fwd_out)[i]->StopGradient());
auto& metas = bwd_in_meta_.at(slot_rank);
if (metas.size() == 0) {
metas.resize(1);
}
auto& meta = metas[0];
meta.SetStopGradient(fwd_out_meta->StopGradient());
// Record TensorMeta
if (phi::DenseTensor::classof(fwd_out.impl().get())) {
// Only Copy Meta
phi::DenseTensor* dense_tensor =
static_cast<phi::DenseTensor*>(fwd_out.impl().get());
PADDLE_ENFORCE_NE(
dense_tensor->meta().dtype, phi::DataType::UNDEFINED,
paddle::platform::errors::Fatal(
"Attempting to copy DenseTensorMeta with phi::DataType::UNDEFINED,"
"which is illegal."));
meta.SetTensorMeta(dense_tensor->meta());
if (paddle::framework::IsComplexType(
paddle::framework::TransToProtoVarType(dense_tensor->type()))) {
need_complex_to_real_ = true;
}
} else {
VLOG(6) << "Unable to initialize the DenseTensorMeta of GradSlotMeta with "
"non-DenseTensor argument.";
}
}
void GradNodeBase::SetGradInMeta(AutogradMeta* fwd_out, size_t slot_rank) {
void GradNodeBase::SetGradInMeta(
const std::vector<paddle::experimental::Tensor>& fwd_out,
size_t slot_rank) {
size_t slot_size = fwd_out.size();
PADDLE_ENFORCE_LE(
slot_rank, (bwd_in_meta_.size() - 1),
paddle::platform::errors::InvalidArgument(
"Slot Rank should less equal than bwd_in_meta_ size, since "
"bwd_in_meta_ is designed to hold as same num as backward "
"inputs."));
auto& meta = bwd_in_meta_.at(slot_rank);
PADDLE_ENFORCE_EQ(meta.IsInitialized(), false,
paddle::platform::errors::PreconditionNotMet(
"Bwd_in_meta should only be init once, Additional "
"initialization for it is forbidden. If you got this "
"error, it indicates bugs in framework."));
auto& metas = bwd_in_meta_.at(slot_rank);
// Init stop gradient vector before use to avoid push back
VLOG(7) << "Init bwd_in_meta_ with slot rank: " << slot_rank;
meta.Init(1);
meta.SetStopGradient(0, fwd_out->StopGradient());
if (metas.size() < slot_size) {
VLOG(7) << "Init bwd_in_meta_ with slot rank: " << slot_rank;
metas.resize(slot_size);
}
for (size_t i = 0; i < slot_size; i++) {
auto& meta = metas[i];
const auto& fwd_out_tensor = fwd_out[i];
auto* fwd_out_meta =
egr::EagerUtils::nullable_autograd_meta(fwd_out_tensor);
PADDLE_ENFORCE_NOT_NULL(fwd_out_meta,
paddle::platform::errors::PreconditionNotMet(
"Bwd_in_meta should only be called while "
"autograd_meta is not null. If you got this "
"error, it indicates bugs in framework."));
if (fwd_out_meta->StopGradient()) {
// Set Stop Gradient only when its true or non-initialized autograd_meta,
// since all default value is false.
meta.SetStopGradient(fwd_out_meta->StopGradient());
}
// Record TensorMeta
if (phi::DenseTensor::classof(fwd_out_tensor.impl().get())) {
// Only Copy Meta
phi::DenseTensor* dense_tensor =
static_cast<phi::DenseTensor*>(fwd_out_tensor.impl().get());
PADDLE_ENFORCE_NE(
dense_tensor->meta().dtype, phi::DataType::UNDEFINED,
paddle::platform::errors::Fatal("Attempting to copy DenseTensorMeta "
"with phi::DataType::UNDEFINED,"
"which is illegal."));
meta.SetTensorMeta(dense_tensor->meta());
if (paddle::framework::IsComplexType(
paddle::framework::TransToProtoVarType(dense_tensor->type()))) {
need_complex_to_real_ = true;
}
} else {
VLOG(6) << "Unable to initialize the DenseTensorMeta of GradSlotMeta "
"with non-DenseTensor argument.";
}
}
}
void GradNodeBase::SetGradOutMeta(std::vector<AutogradMeta*>* fwd_in,
void GradNodeBase::SetGradOutMeta(const paddle::experimental::Tensor& fwd_in,
size_t slot_rank) {
size_t slot_size = fwd_in->size();
auto* fwd_in_meta = egr::EagerUtils::nullable_autograd_meta(fwd_in);
PADDLE_ENFORCE_LE(
slot_rank, (bwd_out_meta_.size() - 1),
(slot_rank + 1), bwd_out_meta_.size(),
paddle::platform::errors::InvalidArgument(
"Slot Rank should less equal than bwd_out_meta_ size, "
"since bwd_out_meta_ is designed to hold as same num as "
"backward outputs."));
auto& meta = bwd_out_meta_.at(slot_rank);
PADDLE_ENFORCE_EQ(meta.IsInitialized(), false,
paddle::platform::errors::PreconditionNotMet(
"Bwd_out_meta should only be init once. Additional "
"initialization for it is forbidden. If you got this "
"error, it indicates bugs in framework."));
auto& metas = bwd_out_meta_.at(slot_rank);
// Init stop gradient vector before use to avoid push back
meta.Init(slot_size);
for (size_t i = 0; i < slot_size; i++) {
if (!(*fwd_in)[i]) {
meta.SetStopGradient(i, true);
continue;
}
if ((*fwd_in)[i]->StopGradient()) {
// Set Stop Gradient only when its true or non-initialized autograd_meta,
// since all default value is false.
meta.SetStopGradient(i, (*fwd_in)[i]->StopGradient());
if (metas.size() == 0) {
metas.resize(1);
}
auto& meta = metas[0];
if (fwd_in_meta) {
meta.SetStopGradient(fwd_in_meta->StopGradient());
} else {
meta.SetStopGradient(true);
}
// Record TensorMeta
if (fwd_in.impl() && fwd_in.impl().get()) {
if (phi::DenseTensor::classof(fwd_in.impl().get())) {
// Only Copy Meta
phi::DenseTensor* dense_tensor =
static_cast<phi::DenseTensor*>(fwd_in.impl().get());
PADDLE_ENFORCE_NE(
dense_tensor->meta().dtype, phi::DataType::UNDEFINED,
paddle::platform::errors::Fatal("Attempting to copy DenseTensorMeta "
"with phi::DataType::UNDEFINED,"
"which is illegal."));
meta.SetTensorMeta(dense_tensor->meta());
}
} else {
VLOG(6) << "Unable to initialize the DenseTensorMeta of GradSlotMeta with "
"non-DenseTensor argument.";
}
}
void GradNodeBase::SetGradOutMeta(AutogradMeta* fwd_in, size_t slot_rank) {
void GradNodeBase::SetGradOutMeta(
const std::vector<paddle::experimental::Tensor>& fwd_in, size_t slot_rank) {
size_t slot_size = fwd_in.size();
PADDLE_ENFORCE_LE(
(slot_rank + 1), bwd_out_meta_.size(),
slot_rank, (bwd_out_meta_.size() - 1),
paddle::platform::errors::InvalidArgument(
"Slot Rank should less equal than bwd_out_meta_ size, "
"since bwd_out_meta_ is designed to hold as same num as "
"backward outputs."));
auto& meta = bwd_out_meta_.at(slot_rank);
PADDLE_ENFORCE_EQ(meta.IsInitialized(), false,
paddle::platform::errors::PreconditionNotMet(
"Bwd_out_meta should only be init once. Additional "
"initialization for it is forbidden. If you got this "
"error, it indicates bugs in framework."));
auto& metas = bwd_out_meta_.at(slot_rank);
// Init stop gradient vector before use to avoid push back
meta.Init(1);
if (fwd_in) {
meta.SetStopGradient(0, fwd_in->StopGradient());
} else {
meta.SetStopGradient(0, true);
if (metas.size() < slot_size) {
metas.resize(slot_size);
}
for (size_t i = 0; i < slot_size; i++) {
const auto& fwd_in_tensor = fwd_in[i];
auto& meta = metas[i];
auto* fwd_in_meta = egr::EagerUtils::nullable_autograd_meta(fwd_in_tensor);
if (fwd_in_meta) {
// Set Stop Gradient only when its true or non-initialized autograd_meta,
// since all default value is false.
meta.SetStopGradient(fwd_in_meta->StopGradient());
}
// Record TensorMeta
if (fwd_in_tensor.impl() && fwd_in_tensor.impl().get()) {
if (phi::DenseTensor::classof(fwd_in_tensor.impl().get())) {
// Only Copy Meta
phi::DenseTensor* dense_tensor =
static_cast<phi::DenseTensor*>(fwd_in_tensor.impl().get());
PADDLE_ENFORCE_NE(dense_tensor->meta().dtype, phi::DataType::UNDEFINED,
paddle::platform::errors::Fatal(
"Attempting to copy DenseTensorMeta with "
"phi::DataType::UNDEFINED,"
"which is illegal."));
meta.SetTensorMeta(dense_tensor->meta());
}
} else {
VLOG(6) << "Unable to initialize the DenseTensorMeta of GradSlotMeta "
"with non-DenseTensor argument.";
}
}
}
......@@ -207,12 +283,8 @@ void GradNodeBase::SetDefaultGradInOutMeta() {
"meta setter, other size of inputs and outputs should "
"create with Setter and Getters"));
// Default stop_gradient is false and slot id is 0, slot size is 1;
bwd_out_meta_[0].Init(1);
bwd_in_meta_[0].Init(1);
}
const std::vector<std::vector<Edge>>& GradNodeBase::GetEdges() const {
return adj_edges_;
bwd_out_meta_[0].resize(1);
bwd_in_meta_[0].resize(1);
}
int64_t GradNodeBase::RegisterGradientHook(
......@@ -222,6 +294,10 @@ int64_t GradNodeBase::RegisterGradientHook(
return next_hook_id_++;
}
const std::vector<std::vector<Edge>>& GradNodeBase::GetEdges() const {
return adj_edges_;
}
std::vector<std::vector<paddle::experimental::Tensor>>
GradNodeBase::ApplyGradientHooks(
const std::vector<std::vector<paddle::experimental::Tensor>>& tensors) {
......@@ -270,4 +346,45 @@ GradNodeBase::ApplyGradientHooks(
return outs;
}
void GradNodeBase::HandleComplexGradToRealGrad(
std::vector<std::vector<paddle::experimental::Tensor>>* out_grads) {
for (size_t slot_id = 0; slot_id < out_grads->size(); slot_id++) {
const std::vector<paddle::experimental::Tensor>& slot_out_grads =
(*out_grads)[slot_id];
for (size_t rank_id = 0; rank_id < slot_out_grads.size(); rank_id++) {
const GradSlotMeta& slot_meta = bwd_out_meta_[slot_id][rank_id];
PADDLE_ENFORCE(
slot_meta.HasTensorMeta() > 0,
paddle::platform::errors::Fatal(
"We require TensorMeta in GradInputMeta() to obtain forward data "
"types."
"However, no TensorMeta is detected in bwd_out_meta_."));
auto fwd_data_type = paddle::framework::TransToProtoVarType(
slot_meta.GetTensorMeta().dtype);
const paddle::experimental::Tensor& grad = slot_out_grads[rank_id];
if (paddle::framework::IsComplexType(fwd_data_type)) continue;
// Only Handle Complex To Real for DenseTensor for now
if (phi::DenseTensor::classof(grad.impl().get())) {
phi::DenseTensor* grad_dense_tensor =
static_cast<phi::DenseTensor*>(grad.impl().get());
auto curr_data_type =
paddle::framework::TransToProtoVarType(grad_dense_tensor->type());
if (!paddle::framework::IsComplexType(curr_data_type)) continue;
// Convert Complex GradOut to Real
auto out = std::make_shared<phi::DenseTensor>();
paddle::framework::TransComplexToReal(fwd_data_type, curr_data_type,
*grad_dense_tensor, out.get());
(*out_grads)[slot_id][rank_id].set_impl(out);
}
}
}
}
} // namespace egr
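As an illustration of the complex-to-real handling above (an assumption about the effect of TransComplexToReal, not its implementation): when the forward tensor recorded in GradSlotMeta is real but the produced gradient is complex, only the real part is kept.

import numpy as np

complex_grad = np.array([1 + 2j, 3 - 1j], dtype=np.complex64)
real_grad = complex_grad.real.astype(np.float32)   # -> [1., 3.]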
......@@ -57,21 +57,28 @@ class AutogradMeta;
class GradSlotMeta {
public:
GradSlotMeta() = default;
void Init(size_t size) {
size_ = static_cast<int>(size);
stop_gradient_.resize(size, false);
bool IsStopGradient() const { return stop_gradient_; }
void SetStopGradient(bool stop_gradient = true) {
stop_gradient_ = stop_gradient;
}
bool IsInitialized() const { return size_ != -1; }
bool IsStopGradient(size_t rank) const { return stop_gradient_[rank]; }
int Size() const { return size_; }
void SetStopGradient(size_t rank, bool stop_gradient = true) {
stop_gradient_.at(rank) = stop_gradient;
void SetTensorMeta(const phi::DenseTensorMeta& meta) {
meta_ = std::make_shared<phi::DenseTensorMeta>(meta);
}
bool HasTensorMeta() const { return meta_ && meta_.get(); }
const phi::DenseTensorMeta& GetTensorMeta() const {
if (!HasTensorMeta()) {
PADDLE_THROW(paddle::platform::errors::Fatal(
"meta_ of GradSlotMeta has not been initialized yet."
"You're expected to check Edge availability with HasTensorMeta()"
"before calling GetTensorMeta() interface."));
}
return *meta_.get();
}
private:
int size_{-1};
std::vector<bool> stop_gradient_{false};
bool stop_gradient_{false};
std::shared_ptr<phi::DenseTensorMeta> meta_ = nullptr;
};
class GradNodeBase {
......@@ -112,25 +119,30 @@ class GradNodeBase {
void AddEdges(std::vector<AutogradMeta*>* metas, size_t slot_id);
void AddEdges(AutogradMeta* meta, size_t slot_id);
/**
* GetEdges is designed to get all edges of current node**/
const std::vector<std::vector<Edge>>& GetEdges() const;
// adj_edges were moved inside OutputMeta(), so no direct access is available
// from GradNodeBase.
// To access Edges, get GradSlotMeta by calling OutputMeta(), then use
// slot_meta.GetEdge()
/**
* Get Input Meta of current Grad node**/
const std::vector<GradSlotMeta>& InputMeta() const;
const std::vector<std::vector<GradSlotMeta>>& InputMeta() const;
/**
* Get Output Meta of current Grad node**/
const std::vector<GradSlotMeta>& OutputMeta() const;
const std::vector<std::vector<GradSlotMeta>>& OutputMeta() const;
/**
* Set bwd ins and outs info with forward vars
* **/
void SetGradInMeta(std::vector<AutogradMeta*>* fwd_out, size_t slot_rank);
void SetGradInMeta(AutogradMeta* fwd_out, size_t slot_rank);
void SetGradInMeta(const std::vector<paddle::experimental::Tensor>& fwd_out,
size_t slot_rank);
void SetGradInMeta(const paddle::experimental::Tensor& fwd_out,
size_t slot_rank);
void SetGradOutMeta(std::vector<AutogradMeta*>* fwd_in, size_t slot_rank);
void SetGradOutMeta(AutogradMeta* fwd_in, size_t slot_rank);
void SetGradOutMeta(const std::vector<paddle::experimental::Tensor>& fwd_in,
size_t slot_rank);
void SetGradOutMeta(const paddle::experimental::Tensor& fwd_in,
size_t slot_rank);
/**
* Default setters for Grad in/out meta this should be used for same special
......@@ -162,11 +174,21 @@ class GradNodeBase {
std::vector<std::vector<paddle::experimental::Tensor>> ApplyGradientHooks(
const std::vector<std::vector<paddle::experimental::Tensor>>& tensors);
/**
* Handle Complex - Real Type Promotion
* **/
void HandleComplexGradToRealGrad(
std::vector<std::vector<paddle::experimental::Tensor>>* out_grads);
bool NeedComplexToRealConversion() { return need_complex_to_real_; }
virtual std::string name() { return "GradNodeBase"; }
private:
// TODO(jiabin): Use SmallVector instead after merge PR from develop
/**
* GetEdges is designed to get all edges of current node**/
const std::vector<std::vector<Edge>>& GetEdges() const;
private:
// TODO(zhanlve): Merge adj_edges_ into GradOutMeta
// Edges record the backward-related node info, which indicates all edges
// linked by this Grad Node.
......@@ -174,10 +196,10 @@ class GradNodeBase {
std::vector<std::vector<Edge>> adj_edges_;
// bwd_out_meta_ is used to record Grad output info for backward
std::vector<GradSlotMeta> bwd_out_meta_;
std::vector<std::vector<GradSlotMeta>> bwd_out_meta_;
// bwd_in_meta_ used to record Grad input info for backward
std::vector<GradSlotMeta> bwd_in_meta_;
std::vector<std::vector<GradSlotMeta>> bwd_in_meta_;
// Gradient Hooks
// Customer may register a list of hooks which will be called in order during
// backward
......@@ -188,6 +210,8 @@ class GradNodeBase {
/* hook */ std::shared_ptr<TensorHook>>>
gradient_hooks_;
// We handle complex to real conversion only if any complex GradIn is involved
bool need_complex_to_real_ = false;
int64_t next_hook_id_{0};
};
......
......@@ -26,12 +26,13 @@ namespace egr {
* GradTensorHolder should have the same format as the forward output **/
class GradTensorHolder {
public:
explicit GradTensorHolder(const std::vector<GradSlotMeta>& meta) {
VLOG(7) << "Init GradTensorHolder with meta size: " << meta.size();
buffer_.resize(meta.size());
explicit GradTensorHolder(
const std::vector<std::vector<GradSlotMeta>>& metas) {
VLOG(7) << "Init GradTensorHolder with meta size: " << metas.size();
buffer_.resize(metas.size());
for (size_t i = 0; i < buffer_.size(); i++) {
VLOG(7) << "Init GradTensorHolder with meta rank: " << meta[i].Size();
buffer_[i].resize(meta[i].Size());
VLOG(7) << "Init GradTensorHolder with meta rank: " << metas[i].size();
buffer_[i].resize(metas[i].size());
}
}
......
......@@ -36,6 +36,15 @@ class TensorWrapper {
explicit TensorWrapper(const paddle::experimental::Tensor& tensor,
bool full_reserved = false,
bool no_need_buffer = false) {
// set inplace_version_snapshot_ according to tensor's current inplace
// version.
if (tensor.impl() && phi::DenseTensor::classof(tensor.impl().get())) {
phi::DenseTensor* dense_tensor =
static_cast<phi::DenseTensor*>(tensor.impl().get());
auto& inplace_version_counter = dense_tensor->InplaceVersionCounter();
inplace_version_snapshot_ = inplace_version_counter.CurrentVersion();
}
/**
* Normally, we should fully reserve all non-output or non-leaf fwd tensors
* here. And for fwd output tensor, we should not reserve its autogradmeta,
......@@ -49,6 +58,7 @@ class TensorWrapper {
}
// shallow copy tensor_impl here
no_need_buffer_ = no_need_buffer;
if (no_need_buffer) {
if (phi::DenseTensor::classof(tensor.impl().get())) {
// Only Copy Meta
......@@ -86,6 +96,7 @@ class TensorWrapper {
// if it's full_reserved just return the full copy of tensor
if (full_reserved_) {
check_inplace_version();
return intermidiate_tensor_;
} else {
std::shared_ptr<GradNodeBase> new_grad_node = grad_node;
......@@ -94,15 +105,52 @@ class TensorWrapper {
intermidiate_tensor_.set_autograd_meta(
std::static_pointer_cast<paddle::experimental::AbstractAutogradMeta>(
p_ab_autograd_meta));
check_inplace_version();
return intermidiate_tensor_;
}
}
void check_inplace_version() {
if (no_need_buffer_) {
VLOG(6) << "There's no need to check inplace_version because "
"no_need_buffer_ is true.";
return;
}
if (intermidiate_tensor_.impl() &&
phi::DenseTensor::classof(intermidiate_tensor_.impl().get())) {
phi::DenseTensor* dense_tensor =
static_cast<phi::DenseTensor*>(intermidiate_tensor_.impl().get());
auto& inplace_version_counter = dense_tensor->InplaceVersionCounter();
uint32_t current_inplace_version =
inplace_version_counter.CurrentVersion();
PADDLE_ENFORCE_EQ(
current_inplace_version, inplace_version_snapshot_,
paddle::platform::errors::PermissionDenied(
"Tensor '%s' used in gradient computation has been "
"modified by an inplace operation. "
"Its version is %d but the expected version is %d. "
"Please fix your code to void calling an inplace operator "
"after using the Tensor which will used in gradient "
"computation.",
intermidiate_tensor_.name(), current_inplace_version,
inplace_version_snapshot_));
VLOG(6) << " The inplace_version_snapshot_ of Tensor '"
<< intermidiate_tensor_.name() << "' is [ "
<< inplace_version_snapshot_ << " ]";
VLOG(6) << " The current_inplace_version of Tensor '"
<< intermidiate_tensor_.name() << "' is [ "
<< current_inplace_version << " ]";
}
}
void clear() { intermidiate_tensor_.reset(); }
private:
bool full_reserved_ = false;
bool no_need_buffer_ = false;
std::pair<size_t, size_t> out_rank_info_;
paddle::experimental::Tensor intermidiate_tensor_;
uint32_t inplace_version_snapshot_ = 0;
};
} // namespace egr
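A hypothetical repro of the guard above (which ops actually capture a TensorWrapper is an assumption): mutating a tensor in place after it has been recorded for backward bumps its inplace version, and backward() is then expected to raise the PermissionDenied error defined in check_inplace_version().

import paddle

x = paddle.randn([2, 3])
x.stop_gradient = False
y = x * x            # x is captured for y's backward
x.tanh_()            # in-place op advances x's inplace version counter
y.sum().backward()   # expected to trip the inplace-version mismatch check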
......@@ -11,6 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "glog/logging.h"
#include "gtest/gtest.h"
......@@ -23,14 +24,9 @@
TEST(GradNodeInfo, GradSlotMeta) {
auto grad_slot = egr::GradSlotMeta();
CHECK(grad_slot.IsInitialized() == false);
VLOG(6) << "Init GradSlotMeta";
grad_slot.Init(2);
CHECK(grad_slot.IsInitialized() == true);
VLOG(6) << "Set SetStopGradient";
grad_slot.SetStopGradient(0);
CHECK(grad_slot.IsStopGradient(0) == true);
CHECK_EQ(grad_slot.Size(), 2);
grad_slot.SetStopGradient();
CHECK(grad_slot.IsStopGradient() == true);
}
void TestGradNodeBase(bool is_remove_gradient_hook) {
......@@ -56,18 +52,22 @@ void TestGradNodeBase(bool is_remove_gradient_hook) {
->data<float>()[0],
6.0f);
VLOG(6) << "Test Add Edges";
egr::Edge edge0(grad_test_node1, 1, 2);
auto auto_grad0 = std::make_shared<egr::AutogradMeta>(edge0);
egr::Edge tmp_edge0(grad_test_node1, 1, 2);
auto auto_grad0 = std::make_shared<egr::AutogradMeta>(tmp_edge0);
auto_grad0->SetStopGradient(false);
egr::Edge edge1(grad_test_node1, 3, 4);
auto auto_grad1 = std::make_shared<egr::AutogradMeta>(edge1);
egr::Edge tmp_edge1(grad_test_node1, 3, 4);
auto auto_grad1 = std::make_shared<egr::AutogradMeta>(tmp_edge1);
et1.set_autograd_meta(auto_grad1);
auto_grad1->SetStopGradient(false);
grad_test_node0->AddEdges(auto_grad0.get(), 0);
CHECK_EQ(grad_test_node0->GetEdges()[0][0].GetEdgeRankInfo().first,
size_t(1));
CHECK_EQ(grad_test_node0->GetEdges()[0][0].GetEdgeRankInfo().second,
size_t(2));
std::vector<egr::AutogradMeta*> metas = {auto_grad1.get()};
grad_test_node0->AddEdges(&metas, 1);
CHECK_EQ(grad_test_node0->GetEdges()[1][0].GetEdgeRankInfo().first,
size_t(3));
......@@ -76,22 +76,30 @@ void TestGradNodeBase(bool is_remove_gradient_hook) {
VLOG(6) << "Test Set Meta and Get Meta";
auto_grad1->SetStopGradient(true);
grad_test_node0->SetGradInMeta(&metas, 0);
grad_test_node0->SetGradInMeta(auto_grad1.get(), 1);
grad_test_node0->SetGradOutMeta(&metas, 0);
grad_test_node0->SetGradOutMeta(auto_grad1.get(), 1);
CHECK_EQ(grad_test_node0->InputMeta()[0].Size(), 1);
CHECK_EQ(grad_test_node0->InputMeta()[1].Size(), 1);
CHECK(grad_test_node0->OutputMeta()[0].IsStopGradient(0));
CHECK(grad_test_node0->OutputMeta()[1].IsStopGradient(0));
grad_test_node0->SetGradInMeta(et1, 0);
grad_test_node0->SetGradInMeta({et1}, 1);
grad_test_node0->SetGradOutMeta(et1, 0);
grad_test_node0->SetGradOutMeta({et1}, 1);
CHECK_EQ(grad_test_node0->InputMeta()[0].size(), size_t(1));
CHECK_EQ(grad_test_node0->InputMeta()[1].size(), size_t(1));
CHECK_EQ(grad_test_node0->InputMeta()[0][0].GetTensorMeta().dtype,
meta.dtype);
CHECK_EQ(grad_test_node0->InputMeta()[1][0].GetTensorMeta().dtype,
meta.dtype);
CHECK(grad_test_node0->OutputMeta()[0][0].IsStopGradient());
CHECK(grad_test_node0->OutputMeta()[1][0].IsStopGradient());
CHECK_EQ(grad_test_node0->OutputMeta()[0][0].GetTensorMeta().dtype,
meta.dtype);
CHECK_EQ(grad_test_node0->OutputMeta()[1][0].GetTensorMeta().dtype,
meta.dtype);
VLOG(6) << "Test Default Set Meta and Get Meta";
auto grad_test_node2 = std::make_shared<eager_test::GradTestNode>(
/* val */ 5.0, /* in_num */ 1, /* out_num */ 1);
grad_test_node2->SetDefaultGradInOutMeta();
CHECK(grad_test_node2->OutputMeta()[0].IsInitialized());
CHECK(grad_test_node2->OutputMeta()[0].IsStopGradient(0) == false);
CHECK_EQ(grad_test_node2->OutputMeta()[0].Size(), 1);
CHECK_GT(grad_test_node2->OutputMeta()[0].size(), size_t(0));
CHECK(grad_test_node2->OutputMeta()[0][0].IsStopGradient() == false);
CHECK_EQ(grad_test_node2->OutputMeta()[0].size(), size_t(1));
VLOG(6) << "Test Gradient Hook";
auto gradient_hook = [](
......@@ -135,7 +143,17 @@ TEST(GradNodeInfo, GradNodeBase) {
}
TEST(GradNodeInfo, Edge) {
phi::DenseTensorMeta meta =
phi::DenseTensorMeta(phi::DataType::FLOAT32, phi::make_ddim({1, 1}));
std::shared_ptr<phi::DenseTensor> dt = std::make_shared<phi::DenseTensor>(
std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace())
.get(),
meta);
paddle::experimental::Tensor et1(dt);
auto grad_test_node0 = std::make_shared<eager_test::GradTestNode>(5, 2, 2);
auto auto_grad1 = std::make_shared<egr::AutogradMeta>();
VLOG(6) << "Test Construct Edge";
egr::Edge edge0 = egr::Edge();
CHECK(edge0.IsInitialized() == false);
......@@ -145,13 +163,12 @@ TEST(GradNodeInfo, Edge) {
egr::Edge(grad_test_node0, std::make_pair(size_t(1), size_t(0)));
VLOG(6) << "Test Set Edge's Grad Node";
auto* grad_node = edge1.GetGradNode();
et1.set_autograd_meta(auto_grad1);
grad_node->SetGradInMeta(et1, 0);
CHECK_EQ(grad_node->InputMeta().size(), size_t(2));
auto mt_grad_node = edge1.GetMutableGradNode();
auto auto_grad1 = std::make_shared<egr::AutogradMeta>();
std::vector<egr::AutogradMeta*> metas = {auto_grad1.get()};
// An uninitialized AutogradMeta is treated as stop_gradient here.
mt_grad_node->SetGradInMeta(&metas, 0);
CHECK(grad_node->InputMeta()[0].IsStopGradient(0) == true);
CHECK(grad_node->InputMeta()[0][0].IsStopGradient() == true);
VLOG(6) << "Test Get/Set Edge Rank Info";
CHECK_EQ(edge2.GetEdgeRankInfo().first, size_t(1));
CHECK_EQ(edge2.GetEdgeRankInfo().second, size_t(0));
......
......@@ -30,8 +30,7 @@ PD_DECLARE_KERNEL(full_like, CPU, ALL_LAYOUT);
using namespace egr; // NOLINT
TEST(GradTensorHolder, Constructor) {
GradSlotMeta slot_meta;
slot_meta.Init(1);
std::vector<GradSlotMeta> slot_meta(1);
GradTensorHolder grad_tensor_holder = GradTensorHolder({slot_meta});
GradTensorHolder grad_tensor_holder2 = GradTensorHolder(grad_tensor_holder);
......@@ -72,8 +71,7 @@ TEST(GradTensorHolder, Interfaces) {
paddle::experimental::Tensor et1 = paddle::experimental::Tensor(dt1);
// Constructor empty GradTensorHolder
GradSlotMeta slot_meta;
slot_meta.Init(1);
std::vector<GradSlotMeta> slot_meta(1);
GradTensorHolder grad_tensor_holder =
GradTensorHolder({slot_meta, slot_meta});
......@@ -138,8 +136,7 @@ TEST(GradTensorHolder, SelectedRowsMergeAdd) {
paddle::experimental::Tensor t2(sr2);
// Constructor empty GradTensorHolder
GradSlotMeta slot_meta;
slot_meta.Init(1);
std::vector<GradSlotMeta> slot_meta(1);
GradTensorHolder grad_tensor_holder =
GradTensorHolder({slot_meta, slot_meta});
......
......@@ -37,7 +37,7 @@
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/memory/memcpy.h"
static size_t max_num_benchmark_runs = 5000;
static size_t max_num_benchmark_runs = 4000;
namespace egr {
......
......@@ -66,10 +66,10 @@ inline void run_program_dygraph_function(
grad_node->SetStepScope(step_scope);
// Set Grad out rank as same as fwd input and set stop gradient to bwd
grad_node->SetGradOutMeta(&p_autograd_x, /*slot id*/ 0);
grad_node->SetGradOutMeta(&p_autograd_params, /*slot id*/ 1);
grad_node->SetGradOutMeta(x, /*slot id*/ 0);
grad_node->SetGradOutMeta(params, /*slot id*/ 1);
grad_node->SetGradInMeta(&p_autograd_outs, 0);
grad_node->SetGradInMeta(deref_out, 0);
// Set Next Edges
grad_node->AddEdges(&p_autograd_x, /*slot id*/ 0);
grad_node->AddEdges(&p_autograd_params, /*slot id*/ 1);
......
......@@ -212,6 +212,27 @@ std::vector<std::shared_ptr<EagerVariable>> EagerUtils::CreateVars(
return res;
}
void EagerUtils::ModifyInplaceInput(
const std::shared_ptr<EagerVariable>& inplace_variable,
paddle::experimental::Tensor* inplace_tensor) {
// Only modify the meta information of the inplace tensor, because
// EagerVariable cannot modify Tensor's meta information after inplace
// op (such as ``reshape``) is executed.
PADDLE_ENFORCE_NOT_NULL(inplace_tensor,
paddle::platform::errors::Fatal(
"Inplace Tensor is null and cannot be modified. "
"We are tring to Modify Inplace Input from its "
"shared_ptr, this error may indicate the inplace "
" input is nullptr"));
if (phi::DenseTensor::classof(inplace_variable->GetTensorBase().get())) {
phi::DenseTensor* variable_dense_tensor =
static_cast<phi::DenseTensor*>(inplace_variable->GetTensorBase().get());
phi::DenseTensor* tensor_dense_tensor =
static_cast<phi::DenseTensor*>(inplace_tensor->impl().get());
tensor_dense_tensor->set_meta(variable_dense_tensor->meta());
}
}
std::vector<paddle::experimental::Tensor> EagerUtils::GetOutputs(
const std::vector<std::shared_ptr<EagerVariable>>& outs) {
std::vector<paddle::experimental::Tensor> res;
......
......@@ -14,6 +14,7 @@
#pragma once
#include "paddle/fluid/eager/api/utils/tensor_utils.h"
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/eager/grad_node_info.h"
......@@ -144,6 +145,19 @@ class EagerUtils {
iter.apply(std::forward<Args>(args)...);
}
static void CheckInplace(const paddle::experimental::Tensor& target,
const AutogradMeta* autograd_meta,
bool require_any_grad) {
if (require_any_grad && autograd_meta) {
PADDLE_ENFORCE_EQ(!autograd_meta->StopGradient() &&
egr::egr_utils_api::IsLeafTensor(target),
false, paddle::platform::errors::InvalidArgument(
"Leaf Var (%s) that doesn't stop gradient "
"can't use inplace strategy.",
target.name()));
}
}
// TensorWrapper Utils
static paddle::experimental::Tensor RecoverTensorWrapper(
TensorWrapper* tw, const std::shared_ptr<GradNodeBase>& grad_node);
......@@ -171,6 +185,9 @@ class EagerUtils {
static std::vector<std::shared_ptr<EagerVariable>> CreateVars(
const size_t num);
// Construct Tensor From var
static void ModifyInplaceInput(
const std::shared_ptr<EagerVariable>& inplace_variable,
paddle::experimental::Tensor* inplace_tensor);
static std::vector<paddle::experimental::Tensor> GetOutputs(
const std::vector<std::shared_ptr<EagerVariable>>& outs);
static paddle::experimental::Tensor GetOutput(
......
......@@ -32,8 +32,9 @@ USE_OP(conv2d_transpose);
USE_OP_DEVICE_KERNEL(conv2d_transpose, MKLDNN);
USE_OP_ITSELF(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
USE_OP(gelu);
USE_OP_ITSELF(gelu);
USE_OP_DEVICE_KERNEL(gelu, MKLDNN);
PD_DECLARE_ARG_MAPPING_FN(gelu);
namespace paddle {
namespace framework {
......
......@@ -18,6 +18,7 @@
#include <unordered_set>
#include <boost/logic/tribool.hpp>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -27,10 +28,11 @@ USE_OP_ITSELF(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
USE_OP_ITSELF(leaky_relu);
USE_OP_DEVICE_KERNEL(leaky_relu, MKLDNN);
USE_OP(gelu);
USE_OP_ITSELF(gelu);
USE_OP_ITSELF(relu);
USE_OP_ITSELF(tanh);
USE_OP_DEVICE_KERNEL(tanh, MKLDNN);
PD_DECLARE_ARG_MAPPING_FN(gelu);
namespace paddle {
namespace framework {
......
......@@ -38,7 +38,7 @@ USE_OP(softmax_with_cross_entropy);
USE_OP_ITSELF(reduce_mean);
USE_OP_ITSELF(reduce_sum);
USE_OP_ITSELF(reduce_sum_grad);
USE_OP(reduce_mean_grad);
USE_OP_ITSELF(reduce_mean_grad);
USE_OP_ITSELF(reshape2_grad);
USE_OP(softmax_with_cross_entropy_grad);
USE_OP_ITSELF(elementwise_add_grad);
......
......@@ -628,10 +628,12 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
bool OpSupportGPU(const std::string& op_type) {
// check in new Function kernel first
bool has_phi_kernel = false;
auto& kernel_factory = phi::KernelFactory::Instance();
auto kernel_key_map =
kernel_factory.SelectKernelMap(phi::TransToPhiKernelName(op_type));
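// Any entry in the phi kernel map marks the op as having a phi kernel;
// return early as soon as a GPU-backend kernel is seen.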
for (auto& kernel : kernel_key_map) {
has_phi_kernel = true;
if (platform::is_gpu_place(phi::TransToPhiPlace(kernel.first.backend()))) {
return true;
}
......@@ -639,12 +641,19 @@ bool OpSupportGPU(const std::string& op_type) {
auto& all_kernels = OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type);
if (it == all_kernels.end()) {
// All control operator must support GPU
return true;
}
for (auto& kern_pair : it->second) {
if (platform::is_gpu_place(kern_pair.first.place_)) {
if (it != all_kernels.end()) {
for (auto& kern_pair : it->second) {
if (platform::is_gpu_place(kern_pair.first.place_)) {
return true;
}
}
} else {
if (has_phi_kernel) {
// The op has a phi kernel, but neither a phi GPU kernel nor a fluid GPU
// kernel was found, so it does not support GPU.
return false;
} else {
// All control operators must support GPU.
return true;
}
}
......@@ -2347,6 +2356,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
const auto& vector_int_attr =
BOOST_GET_CONST(std::vector<int>, attr_it->second);
pt_kernel_context->EmplaceBackAttr(vector_int_attr);
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<std::string>))) {
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<std::string>, attr_it->second));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` when construct "
......
......@@ -541,6 +541,10 @@ void BuildDygraphPhiKernelContext(
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int>))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector<int>, attr));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<std::string>))) {
kernel_ctx->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<std::string>, attr));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` when construct "
......
......@@ -53,7 +53,11 @@ if [ $7 == ON ]; then
if [[ -e "MobileNetV2.inference.model.tar.gz" ]]; then
echo "MobileNetV2.inference.model.tar.gz has been downloaded."
else
wget -q --no-proxy http://paddle-inference-dist.bj.bcebos.com/MobileNetV2.inference.model.tar.gz
if [ $WIN_DETECT != "" ]; then
wget -q -Y off http://paddle-inference-dist.bj.bcebos.com/MobileNetV2.inference.model.tar.gz
else
wget -q --no-proxy http://paddle-inference-dist.bj.bcebos.com/MobileNetV2.inference.model.tar.gz
fi
tar xzf *.tar.gz
fi
cd ..
......
......@@ -219,6 +219,12 @@ class AllocatorFacadePrivate {
}
InitNaiveBestFitCUDAPinnedAllocator();
#endif
#ifdef PADDLE_WITH_ASCEND_CL
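// One naive best-fit allocator per NPU device, plus an NPU pinned-memory allocator.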
for (int dev_id = 0; dev_id < platform::GetNPUDeviceCount(); ++dev_id) {
InitNaiveBestFitNPUAllocator(platform::NPUPlace(dev_id));
}
InitNaiveBestFitNPUPinnedAllocator();
#endif
#ifdef PADDLE_WITH_XPU
for (int dev_id = 0; dev_id < platform::GetXPUDeviceCount(); ++dev_id) {
InitNaiveBestFitXPUAllocator(platform::XPUPlace(dev_id));
......
......@@ -113,23 +113,5 @@ class BatchNormOpInferVarType
}
};
template <typename DeviceContext, typename T>
class BatchNormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
};
template <typename DeviceContext, typename T>
class BatchNormGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
};
template <typename DeviceContext, typename T>
class BatchNormDoubleGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
};
} // namespace operators
} // namespace paddle
......@@ -76,10 +76,10 @@ class NPUBatchNormOpKernel : public framework::OpKernel<T> {
auto *variance_out = ctx.Output<Tensor>("VarianceOut");
auto *saved_mean = ctx.Output<Tensor>("SavedMean");
auto *saved_variance = ctx.Output<Tensor>("SavedVariance");
mean_out->mutable_data<T>(ctx.GetPlace());
variance_out->mutable_data<T>(ctx.GetPlace());
saved_mean->mutable_data<T>(ctx.GetPlace());
saved_variance->mutable_data<T>(ctx.GetPlace());
mean_out->mutable_data<float>(ctx.GetPlace());
variance_out->mutable_data<float>(ctx.GetPlace());
saved_mean->mutable_data<float>(ctx.GetPlace());
saved_variance->mutable_data<float>(ctx.GetPlace());
// if MomentumTensor is set, use MomentumTensor value, momentum
// is only used in this training branch
......@@ -170,8 +170,8 @@ class NPUBatchNormGradOpKernel : public framework::OpKernel<T> {
auto stream = ctx.template device_context<NPUDeviceContext>().stream();
if (d_scale && d_bias) {
d_scale->mutable_data<T>(ctx.GetPlace());
d_bias->mutable_data<T>(ctx.GetPlace());
d_scale->mutable_data<float>(ctx.GetPlace());
d_bias->mutable_data<float>(ctx.GetPlace());
if (use_global_stats) {
const auto *running_mean = ctx.Input<Tensor>("Mean");
const auto *running_variance = ctx.Input<Tensor>("Variance");
......
......@@ -27,6 +27,9 @@ limitations under the License. */
#endif
#include "paddle/fluid/platform/cudnn_workspace_helper.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
......@@ -841,6 +844,8 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType(
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(conv2d, Conv2dInferShapeFunctor,
PD_INFER_META(phi::ConvInferMeta));
REGISTER_OPERATOR(conv2d, ops::ConvOp, ops::Conv2DOpMaker,
ops::ConvOpInferVarType,
ops::Conv2DGradMaker<paddle::framework::OpDesc>,
......@@ -851,6 +856,8 @@ REGISTER_OPERATOR(conv2d_grad, ops::ConvOpGrad,
REGISTER_OPERATOR(conv2d_grad_grad, ops::ConvOpDoubleGrad);
// depthwise convolution op
DECLARE_INFER_SHAPE_FUNCTOR(depthwise_conv2d, DepthwiseConv2dInferShapeFunctor,
PD_INFER_META(phi::ConvInferMeta));
REGISTER_OPERATOR(depthwise_conv2d, ops::ConvOp, ops::Conv2DOpMaker,
ops::ConvOpInferVarType,
ops::Conv2DGradMaker<paddle::framework::OpDesc>,
......@@ -860,6 +867,8 @@ REGISTER_OPERATOR(depthwise_conv2d_grad, ops::ConvOpGrad,
ops::Conv2DDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(depthwise_conv2d_grad_grad, ops::ConvOpDoubleGrad);
DECLARE_INFER_SHAPE_FUNCTOR(conv3d, Conv3dInferShapeFunctor,
PD_INFER_META(phi::ConvInferMeta));
REGISTER_OPERATOR(conv3d, ops::ConvOp, ops::Conv3DOpMaker,
ops::ConvOpInferVarType,
ops::Conv3DGradMaker<paddle::framework::OpDesc>,
......
......@@ -356,7 +356,7 @@ class NPUConvGradOpKernel : public framework::OpKernel<T> {
auto stream = ctx.template device_context<NPUDeviceContext>().stream();
if (filter_grad) {
filter_grad->mutable_data<T>(ctx.GetPlace());
filter_grad->mutable_data<float>(ctx.GetPlace());
std::vector<int> filter_shape_vec = phi::vectorize<int>(filter->dims());
const auto& runner = NpuOpRunner(
......
......@@ -338,8 +338,6 @@ REGISTER_OPERATOR(deformable_conv, ops::DeformableConvOp,
REGISTER_OPERATOR(deformable_conv_grad, ops::DeformableConvGradOp);
REGISTER_OP_CPU_KERNEL(deformable_conv, ops::DeformableConvCPUKernel<float>,
ops::DeformableConvCPUKernel<double>);
REGISTER_OP_CPU_KERNEL(deformable_conv_grad,
ops::DeformableConvGradCPUKernel<float>,
ops::DeformableConvGradCPUKernel<double>);
......@@ -446,108 +446,6 @@ __global__ void FilterGradAddupGpuKernel(const int nthreads, const int n,
}
}
template <typename DeviceContext, typename T>
class DeformableConvCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* input = ctx.Input<Tensor>("Input");
const Tensor offset = *ctx.Input<Tensor>("Offset");
const Tensor mask = *ctx.Input<Tensor>("Mask");
Tensor filter = *ctx.Input<Tensor>("Filter");
Tensor* output = ctx.Output<Tensor>("Output");
output->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.cuda_device_context();
const int groups = ctx.Attr<int>("groups");
const int deformable_groups = ctx.Attr<int>("deformable_groups");
const int im2col_step = ctx.Attr<int>("im2col_step");
const std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
const std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
const std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
const int batch_size = static_cast<int>(input->dims()[0]);
std::vector<int64_t> filter_shape_vec(phi::vectorize(filter.dims()));
std::vector<int64_t> output_shape_vec(phi::vectorize(output->dims()));
// col_shape_vec: {c_i * k_h * k_w, im2col_step, o_h, o_w}
std::vector<int64_t> col_buffer_shape_vec(filter_shape_vec.size());
col_buffer_shape_vec[0] =
input->dims()[1] * filter.dims()[2] * filter.dims()[3];
col_buffer_shape_vec[1] = im2col_step;
for (size_t j = 0; j < filter_shape_vec.size() - 2; ++j) {
col_buffer_shape_vec[j + 2] = output_shape_vec[j + 2];
}
framework::DDim col_shape(phi::make_ddim(col_buffer_shape_vec));
std::vector<int64_t> output_buffer_shape_vec(1);
output_buffer_shape_vec[0] = batch_size * output_shape_vec[1] *
output_shape_vec[2] * output_shape_vec[3];
framework::DDim output_shape(phi::make_ddim(output_buffer_shape_vec));
Tensor col_buffer;
Tensor output_buffer;
col_buffer = ctx.AllocateTmpTensor<T, DeviceContext>(col_shape, dev_ctx);
output_buffer =
ctx.AllocateTmpTensor<T, DeviceContext>(output_shape, dev_ctx);
int64_t M = output_shape_vec[1] / groups;
int64_t N = im2col_step * output_shape_vec[2] * output_shape_vec[3];
int64_t K =
input->dims()[1] * filter_shape_vec[2] * filter_shape_vec[3] / groups;
Tensor weight_3d;
weight_3d.ShareDataWith(filter).Resize(phi::make_ddim({groups, M, K}));
Tensor col_buffer_3d;
col_buffer_3d.ShareDataWith(col_buffer)
.Resize(phi::make_ddim({groups, K, N}));
Tensor output_4d;
output_4d.ShareDataWith(output_buffer)
.Resize(phi::make_ddim({batch_size / im2col_step, groups, M, N}));
output_4d.mutable_data<T>(ctx.GetPlace());
framework::DDim input_shape =
phi::slice_ddim(input->dims(), 1, input->dims().size());
std::vector<int64_t> input_shape_vec = phi::vectorize(input_shape);
int input_dim = input->numel() / input->dims()[0];
int input_offset_dim = offset.numel() / offset.dims()[0];
int input_mask_dim = mask.numel() / mask.dims()[0];
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
const T* input_ptr = input->data<T>();
const T* offset_ptr = offset.data<T>();
const T* mask_ptr = mask.data<T>();
col_buffer.mutable_data<T>(ctx.GetPlace());
T* col_buffer_ptr = col_buffer.data<T>();
for (int i = 0; i < batch_size / im2col_step; ++i) {
ModulatedDeformableIm2col(
ctx.device_context(), input_ptr + i * im2col_step * input_dim,
offset_ptr + i * im2col_step * input_offset_dim,
mask_ptr + i * im2col_step * input_mask_dim, input_shape_vec,
col_buffer_shape_vec, filter_shape_vec, paddings, strides, dilations,
deformable_groups, col_buffer_ptr);
Tensor output_3d = output_4d.Slice(i, i + 1).Resize(
phi::slice_ddim(output_4d.dims(), 1, output_4d.dims().size()));
for (int g = 0; g < groups; ++g) {
Tensor weight_3d_slice = weight_3d.Slice(g, g + 1).Resize(
phi::slice_ddim(weight_3d.dims(), 1, weight_3d.dims().size()));
Tensor col_buffer_3d_slice =
col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim(
col_buffer_3d.dims(), 1, col_buffer_3d.dims().size()));
Tensor output_3d_slice = output_3d.Slice(g, g + 1).Resize(
phi::slice_ddim(output_3d.dims(), 1, output_3d.dims().size()));
blas.MatMul(weight_3d_slice, false, col_buffer_3d_slice, false, T(1.0),
&output_3d_slice, T(0.0));
}
}
output->ShareDataWith(output_buffer)
.Resize(phi::make_ddim(output_shape_vec));
}
};
template <typename DeviceContext, typename T>
class DeformableConvGradCUDAKernel : public framework::OpKernel<T> {
public:
......@@ -740,9 +638,6 @@ class DeformableConvGradCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(deformable_conv,
ops::DeformableConvCUDAKernel<CUDA, float>,
ops::DeformableConvCUDAKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(deformable_conv_grad,
ops::DeformableConvGradCUDAKernel<CUDA, float>,
ops::DeformableConvGradCUDAKernel<CUDA, double>);
......@@ -318,102 +318,6 @@ void FilterGradAddupCPUKernel(const int nthreads, const int n, const int height,
}
}
template <typename T>
class DeformableConvCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("Input");
auto* offset = ctx.Input<Tensor>("Offset");
auto* mask = ctx.Input<Tensor>("Mask");
Tensor filter = *ctx.Input<Tensor>("Filter");
Tensor* output = ctx.Output<Tensor>("Output");
output->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<CPUDeviceContext>();
const int groups = ctx.Attr<int>("groups");
const int deformable_groups = ctx.Attr<int>("deformable_groups");
const int im2col_step = ctx.Attr<int>("im2col_step");
const std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
const std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
const std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
const int batch_size = static_cast<int>(input->dims()[0]);
std::vector<int64_t> filter_shape_vec(phi::vectorize(filter.dims()));
std::vector<int64_t> output_shape_vec(phi::vectorize(output->dims()));
// col_shape_vec: {c_i * k_h * k_w, im2col_step, o_h, o_w}
std::vector<int64_t> col_buffer_shape_vec(filter_shape_vec.size());
col_buffer_shape_vec[0] =
input->dims()[1] * filter.dims()[2] * filter.dims()[3];
col_buffer_shape_vec[1] = im2col_step;
for (size_t j = 0; j < filter_shape_vec.size() - 2; ++j) {
col_buffer_shape_vec[j + 2] = output_shape_vec[j + 2];
}
framework::DDim col_shape(phi::make_ddim(col_buffer_shape_vec));
std::vector<int64_t> output_buffer_shape_vec(1);
output_buffer_shape_vec[0] = batch_size * output_shape_vec[1] *
output_shape_vec[2] * output_shape_vec[3];
framework::DDim output_shape(phi::make_ddim(output_buffer_shape_vec));
Tensor col_buffer;
Tensor output_buffer;
col_buffer = ctx.AllocateTmpTensor<T, CPUDeviceContext>(col_shape, dev_ctx);
output_buffer =
ctx.AllocateTmpTensor<T, CPUDeviceContext>(output_shape, dev_ctx);
int64_t M = output_shape_vec[1] / groups;
int64_t N = im2col_step * output_shape_vec[2] * output_shape_vec[3];
int64_t K =
input->dims()[1] * filter_shape_vec[2] * filter_shape_vec[3] / groups;
Tensor weight_3d;
weight_3d.ShareDataWith(filter).Resize(phi::make_ddim({groups, M, K}));
Tensor col_buffer_3d;
col_buffer_3d.ShareDataWith(col_buffer)
.Resize(phi::make_ddim({groups, K, N}));
Tensor output_4d;
output_4d.ShareDataWith(output_buffer)
.Resize(phi::make_ddim({batch_size / im2col_step, groups, M, N}));
output_4d.mutable_data<T>(ctx.GetPlace());
framework::DDim input_shape =
phi::slice_ddim(input->dims(), 1, input->dims().size());
std::vector<int64_t> input_shape_vec = phi::vectorize(input_shape);
int input_dim = input->numel() / input->dims()[0];
int input_offset_dim = offset->numel() / offset->dims()[0];
int input_mask_dim = mask->numel() / mask->dims()[0];
auto blas = phi::funcs::GetBlas<CPUDeviceContext, T>(dev_ctx);
const T* input_ptr = input->data<T>();
const T* offset_ptr = offset->data<T>();
const T* mask_ptr = mask->data<T>();
col_buffer.mutable_data<T>(ctx.GetPlace());
T* col_buffer_ptr = col_buffer.data<T>();
for (int i = 0; i < batch_size / im2col_step; ++i) {
ModulatedDeformableIm2colCPU(
dev_ctx, input_ptr + i * im2col_step * input_dim,
offset_ptr + i * im2col_step * input_offset_dim,
mask_ptr + i * im2col_step * input_mask_dim, input_shape_vec,
col_buffer_shape_vec, filter_shape_vec, paddings, strides, dilations,
deformable_groups, col_buffer_ptr);
Tensor output_3d = output_4d.Slice(i, i + 1).Resize(
phi::slice_ddim(output_4d.dims(), 1, output_4d.dims().size()));
// get the product of pixel and weight
for (int g = 0; g < groups; ++g) {
Tensor weight_3d_slice = weight_3d.Slice(g, g + 1).Resize(
phi::slice_ddim(weight_3d.dims(), 1, weight_3d.dims().size()));
Tensor col_buffer_3d_slice =
col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim(
col_buffer_3d.dims(), 1, col_buffer_3d.dims().size()));
Tensor output_3d_slice = output_3d.Slice(g, g + 1).Resize(
phi::slice_ddim(output_3d.dims(), 1, output_3d.dims().size()));
blas.MatMul(weight_3d_slice, false, col_buffer_3d_slice, false, T(1.0),
&output_3d_slice, T(0.0));
}
}
output->ShareDataWith(output_buffer)
.Resize(phi::make_ddim(output_shape_vec));
}
};
template <typename T>
class DeformableConvGradCPUKernel : public framework::OpKernel<T> {
public:
......
......@@ -9,8 +9,10 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
......@@ -235,10 +237,13 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(yolo_box, YoloBoxInferShapeFunctor,
PD_INFER_META(phi::YoloBoxInferMeta));
REGISTER_OPERATOR(
yolo_box, ops::YoloBoxOp, ops::YoloBoxOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
YoloBoxInferShapeFunctor);
REGISTER_OP_VERSION(yolo_box)
.AddCheckpoint(
......
......@@ -14,7 +14,9 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -25,17 +27,6 @@ class DropoutOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Dropout");
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", x_dims);
if (ctx->Attrs().Get<bool>("is_test") == false) {
ctx->SetOutputDim("Mask", x_dims);
}
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -173,7 +164,11 @@ class DropoutGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(dropout, DropoutInferShapeFunctor,
PD_INFER_META(phi::DropoutInferMeta));
REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker,
ops::DropoutGradOpMaker<paddle::framework::OpDesc>,
ops::DropoutGradOpMaker<paddle::imperative::OpBase>);
ops::DropoutGradOpMaker<paddle::imperative::OpBase>,
DropoutInferShapeFunctor);
REGISTER_OPERATOR(dropout_grad, ops::DropoutOpGrad);
......@@ -198,10 +198,7 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p, *dst_memory);
}
// elementwise_mul & elementwise_div
else {
} else { // elementwise_mul & elementwise_div
platform::BinaryMKLDNNHandler<T> binary_handler(
BINARY_OP, axis, onednn_engine, ctx.GetPlace(), dout, y, dx, 1.0f,
1.0f, 1.0f);
......@@ -253,10 +250,7 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
} else {
broadcast_src_memory = reorder_src_memory_p;
}
}
// elementwise_mul & elementwise_div
else {
} else { // elementwise_mul & elementwise_div
std::unordered_map<int, dnnl::memory> args;
std::shared_ptr<dnnl::binary> binary_prim;
std::shared_ptr<dnnl::memory> post_op_memory;
......
......@@ -120,6 +120,142 @@ class Conv2DFusionOp : public operators::ConvOp {
ctx->SetOutputsDim("Outputs", output_shapes);
}
}
std::vector<int64_t> ComputeOutputShape(
framework::InferShapeContext* ctx) const {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Conv");
OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter", "Conv");
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
std::string padding_algorithm =
ctx->Attrs().Get<std::string>("padding_algorithm");
int groups = ctx->Attrs().Get<int>("groups");
std::vector<int> dilations =
ctx->Attrs().Get<std::vector<int>>("dilations");
int dilation_size = dilations.size();
for (int i = 0; i < dilation_size; ++i) {
PADDLE_ENFORCE_GT(
dilations[i], 0,
platform::errors::InvalidArgument(
"The dilation of Op(Conv) should be larget than 0, but received "
"dilation is %d.",
dilations[i]));
}
const std::string data_format =
ctx->Attrs().Get<std::string>("data_format");
// MKL-DNN kernels use the NCHW order of dims, so data_format is
// ignored when an MKL-DNN kernel is run.
const bool channel_last = (ctx->IsRunMKLDNNKernel() == false) &&
(data_format == "NHWC" || data_format == "NDHWC");
PADDLE_ENFORCE_EQ(
in_dims.size() == 4 || in_dims.size() == 5, true,
platform::errors::InvalidArgument(
"The input of Op(Conv) should be a 4-D or 5-D Tensor. But "
"received: input's dimension is %u, input's shape is [%s].",
in_dims.size(), in_dims));
PADDLE_ENFORCE_EQ(
in_dims.size(), filter_dims.size(),
platform::errors::InvalidArgument(
"The input's dimension and filter's dimension of "
"Op(Conv) should be equal. But received: the input's shape is "
"[%s], "
"the input's dimension is %d; the filter's shape is [%s], "
"the filter's dimension is %d.",
in_dims, in_dims.size(), filter_dims, filter_dims.size()));
int stride_size = strides.size();
for (int i = 0; i < stride_size; ++i) {
PADDLE_ENFORCE_GT(
strides[i], 0,
platform::errors::InvalidArgument(
"The stride of Op(Conv) should be larget than 0, but received "
"stride is %d.",
strides[i]));
}
int in_sub_stride_size = in_dims.size() - stride_size;
PADDLE_ENFORCE_EQ(
in_dims.size(), strides.size() + 2U,
platform::errors::InvalidArgument(
"The difference of input's dimension and Attr(strides)'s "
"length must be euqal to 2 for Op(Conv). "
"But received: input's dimension is %d, input's shape is [%s]; "
"Attr(stride)'s length is %d, Attr(stride) is [%s]; "
"difference of input's dimention and Attr(strides)'s length = %u.",
in_dims.size(), in_dims, strides.size(), phi::make_ddim(strides),
in_sub_stride_size));
const auto input_channels =
channel_last ? in_dims[in_dims.size() - 1] : in_dims[1];
PADDLE_ENFORCE_EQ(
input_channels, filter_dims[1] * groups,
platform::errors::InvalidArgument(
"The number of input's channels should be equal to filter's "
"channels "
"* groups for Op(Conv). But received: the input's channels is %d, "
"the input's shape is [%s]; the filter's channels is %d, the "
"filter's shape is [%s]; the groups is %d, the data_format is %s. "
"The error may come from wrong data_format setting.",
input_channels, in_dims, filter_dims[1], filter_dims, groups,
data_format));
PADDLE_ENFORCE_EQ(
filter_dims[0] % groups, 0,
platform::errors::InvalidArgument(
"The number of output's channels (filter's first dimension) of "
"Op(Conv) should be divided by groups. But received: "
"the output channels is %d, the filter's shape is [%s], "
"the groups is %d.",
filter_dims[0], filter_dims, groups));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_GT(
filter_dims[0], 0,
platform::errors::InvalidArgument(
"the size of filter at axis 0 should be greater than 0"));
}
framework::DDim in_data_dims;
if (channel_last) {
in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
} else {
in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
}
framework::DDim filter_data_dims =
phi::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
std::vector<int64_t> output_shape({in_dims[0]});
if (!channel_last) {
output_shape.push_back(filter_dims[0]);
}
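// Each spatial output dim follows the usual convolution formula:
// (in + pad_before + pad_after - (dilation * (ksize - 1) + 1)) / stride + 1;
// dims unknown at compile time are kept as -1.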
for (int i = 0; i < in_data_dims.size(); ++i) {
if ((!ctx->IsRuntime()) &&
(in_data_dims[i] <= 0 || filter_dims[i + 2] <= 0)) {
output_shape.push_back(-1);
} else {
output_shape.push_back(
ConvOutputSize(in_data_dims[i], filter_data_dims[i], dilations[i],
paddings[2 * i], paddings[2 * i + 1], strides[i]));
}
}
if (channel_last) {
output_shape.push_back(filter_dims[0]);
}
return output_shape;
}
};
// TODO(qingqing): add gradient operator for conv2d_fusion
......
......@@ -14,10 +14,11 @@ limitations under the License. */
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/operators/gelu_op.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -29,18 +30,6 @@ class GeluOp : public framework::OperatorWithKernel {
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(%s) of GeluOp should not be null.", "X"));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(%s) of GeluOp should not be null.", "Out"));
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
......@@ -156,13 +145,10 @@ class GeluGradOpMaker : public framework::SingleGradOpMaker<T> {
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(gelu, GeluInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(gelu, ops::GeluOp, ops::GeluOpMaker,
ops::GeluGradOpMaker<paddle::framework::OpDesc>,
ops::GeluGradOpMaker<paddle::imperative::OpBase>);
ops::GeluGradOpMaker<paddle::imperative::OpBase>,
GeluInferShapeFunctor);
REGISTER_OPERATOR(gelu_grad, ops::GeluGradOp);
REGISTER_OP_CPU_KERNEL(
gelu, ops::GeluKernel<paddle::platform::CPUDeviceContext, float>,
ops::GeluKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
gelu_grad, ops::GeluGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GeluGradKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -15,7 +15,9 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/operators/gelu_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
......
......@@ -30,7 +30,7 @@ limitations under the License. */
namespace f = paddle::framework;
namespace p = paddle::platform;
USE_OP(gelu);
USE_OP_ITSELF(gelu);
USE_OP_DEVICE_KERNEL(gelu, NPU);
template <typename T>
......
......@@ -14,9 +14,9 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/operators/gelu_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace operators {
......
......@@ -12,9 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/hierarchical_sigmoid_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
......@@ -60,31 +64,6 @@ namespace operators {
class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "hsigmoid");
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "hsigmoid");
OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "hsigmoid");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "hsigmoid");
OP_INOUT_CHECK(ctx->HasOutput("PreOut"), "Output", "PreOut", "hsigmoid");
auto with_prefetch = ctx->Attrs().Get<bool>("remote_prefetch");
if (with_prefetch) {
OP_INOUT_CHECK(ctx->HasOutput("W_Out"), "Output", "W_Out", "hsigmoid");
}
const int64_t input_dims = ctx->GetInputDim("X")[0];
const int64_t label_dims = ctx->GetInputDim("Label")[0];
PADDLE_ENFORCE_EQ(input_dims, label_dims,
platform::errors::InvalidArgument(
"The first dimension of "
"input and label is expected to be the same. "
"But received input's first dimension is %d; "
"label's first dimension is %d.",
input_dims, label_dims));
std::vector<int64_t> output_shape({input_dims, 1});
ctx->SetOutputDim("Out", phi::make_ddim(output_shape));
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
......@@ -272,22 +251,14 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
ops::HierarchicalSigmoidOpMaker<int>,
ops::HierarchicalSigmoidGradMaker<paddle::framework::OpDesc>,
ops::HierarchicalSigmoidGradMaker<paddle::imperative::OpBase>);
DECLARE_INFER_SHAPE_FUNCTOR(hierarchical_sigmoid,
HierarchicalSigmoidInferShapeFunctor,
PD_INFER_META(phi::HierarchicalSigmoidInferMeta));
REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
ops::HierarchicalSigmoidOpMaker<int>,
ops::HierarchicalSigmoidGradMaker<paddle::framework::OpDesc>,
ops::HierarchicalSigmoidGradMaker<paddle::imperative::OpBase>,
HierarchicalSigmoidInferShapeFunctor);
REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp,
ops::HierarchicalSigmoidGradOpGradVarTypeInference,
ops::HierarchicalSigmoidGradOpNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(
hierarchical_sigmoid,
ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext,
double>);
REGISTER_OP_CPU_KERNEL(
hierarchical_sigmoid_grad,
ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
float>,
ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <iterator>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include "paddle/fluid/platform/transform.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
using platform::Transform;
using framework::LoDTensor;
static std::vector<int64_t> PathToRows(const LoDTensor& path) {
std::set<int64_t> rows;
const int64_t* paths = path.data<int64_t>();
for (int64_t i = 0; i < path.numel(); ++i) {
int64_t row = paths[i];
if (row < 0) {
continue;
}
rows.emplace(row);
}
return std::vector<int64_t>(rows.begin(), rows.end());
}
template <typename DeviceContext, typename T>
class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& in = GET_DATA_SAFELY(ctx.Input<LoDTensor>("X"), "Input", "X",
"HierarchicalSigmoid");
auto& w = GET_DATA_SAFELY(ctx.Input<LoDTensor>("W"), "Input", "W",
"HierarchicalSigmoid");
auto* path = ctx.Input<LoDTensor>("PathTable");
auto* code = ctx.Input<LoDTensor>("PathCode");
auto& label = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Label"), "Input",
"Label", "HierarchicalSigmoid");
auto* bias = ctx.Input<LoDTensor>("Bias");
auto* out = ctx.Output<LoDTensor>("Out");
auto* pre_out = ctx.Output<LoDTensor>("PreOut");
size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
// for remote prefetch
bool is_custom = false;
if (path) {
is_custom = true;
}
int64_t code_length =
path ? path->dims()[1] : math::FindLastSet(num_classes - 1);
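// A custom PathTable supplies its own code length; for the default complete
// binary tree it is FindLastSet(num_classes - 1), roughly ceil(log2(num_classes)).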
int64_t batch_size = in.dims()[0];
LoDTensor sum;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto* pre_out_data = pre_out->mutable_data<T>(
phi::make_ddim({batch_size, code_length}), ctx.GetPlace());
auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
// Not all class (leaf) nodes' path lengths equal code_length, so
// initializing with zeros avoids counting the out-of-path loss.
phi::funcs::SetConstant<DeviceContext, T> zero;
zero(dev_ctx, pre_out, static_cast<T>(0.0));
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
phi::funcs::RowwiseSum<DeviceContext, T> row_sum;
std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
if (!is_custom) {
bit_code.reset(new math::MatrixBitCodeFunctor<T>(
num_classes, label.template data<int64_t>()));
} else {
bit_code.reset(new math::MatrixBitCodeFunctor<T>(
*path, *code, label.template data<int64_t>()));
}
std::vector<int64_t> sum_dims({batch_size, 1UL});
sum.mutable_data<T>(phi::make_ddim(sum_dims), ctx.GetPlace());
auto sum_mat = EigenMatrix<T>::From(sum);
out->mutable_data<T>(ctx.GetPlace());
auto out_mat = framework::EigenMatrix<T>::From(*out);
if (bias) {
bit_code->Add(*bias, pre_out);
}
bit_code->Mul(pre_out, w, in);
// clip to [-40, 40]
Transform<DeviceContext> trans;
trans(ctx.template device_context<DeviceContext>(), pre_out_data,
pre_out_data + pre_out->numel(), pre_out_data,
ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
bit_code->Sum(*pre_out, out, static_cast<T>(-1));
// use softrelu to calculate cross entropy
pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
row_sum(dev_ctx, *pre_out, &sum);
// TODO(guosheng): Subtract the out-of-path loss, since not all
// class (leaf) nodes' path lengths equal code_length. It does not break the
// gradient check, because both terms contain the out-of-path loss and cancel
// each other out.
out_mat.device(place) = sum_mat + out_mat;
}
};
template <typename DeviceContext, typename T>
class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& in = GET_DATA_SAFELY(ctx.Input<LoDTensor>("X"), "Input", "X",
"HierarchicalSigmoidGrad");
auto& w = GET_DATA_SAFELY(ctx.Input<LoDTensor>("W"), "Input", "W",
"HierarchicalSigmoidGrad");
auto* path = ctx.Input<LoDTensor>("PathTable");
auto* code = ctx.Input<LoDTensor>("PathCode");
auto* in_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
bool is_sparse = ctx.Attr<bool>("is_sparse");
auto& dev_ctx = ctx.template device_context<DeviceContext>();
phi::funcs::SetConstant<DeviceContext, T> zero;
auto& label = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Label"), "Input",
"Label", "HierarchicalSigmoidGrad");
auto& pre_out = GET_DATA_SAFELY(ctx.Input<LoDTensor>("PreOut"), "Input",
"PreOut", "HierarchicalSigmoidGrad");
auto& out_grad = GET_DATA_SAFELY(
ctx.Input<LoDTensor>(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "HierarchicalSigmoidGrad");
LoDTensor pre_out_grad;
pre_out_grad.mutable_data<T>(pre_out.dims(), ctx.GetPlace());
in_grad->mutable_data<T>(ctx.GetPlace());
zero(dev_ctx, in_grad, static_cast<T>(0.0));
size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
bool is_custom = false;
if (path) {
is_custom = true;
}
std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
if (!is_custom) {
bit_code.reset(new math::MatrixBitCodeFunctor<T>(
num_classes, label.template data<int64_t>()));
} else {
bit_code.reset(new math::MatrixBitCodeFunctor<T>(
*path, *code, label.template data<int64_t>()));
}
// softrelu derivative
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
auto* pre_out_grad_data = pre_out_grad.data<T>();
auto* pre_out_data = pre_out.template data<T>();
auto n = pre_out.numel();
blas.VEXP(n, pre_out_data, pre_out_grad_data);
blas.VINV(n, pre_out_grad_data, pre_out_grad_data);
for (int64_t i = 0; i < n; ++i) {
pre_out_grad_data[i] = 1.0 - pre_out_grad_data[i];
}
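// pre_out_grad now holds 1 - exp(-PreOut) = sigmoid(clipped pre-activation),
// i.e. the derivative of the softrelu applied in the forward pass.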
bit_code->Sub(&pre_out_grad); // the gradient of clip(w * x + b)
auto* out_grad_data = out_grad.template data<T>();
int64_t dim0 = pre_out_grad.dims()[0];
int64_t dim1 = pre_out_grad.dims()[1];
for (int64_t i = 0; i < dim0; ++i) {
T tmp = out_grad_data[i];
blas.SCAL(dim1, tmp, pre_out_grad_data + i * dim1);
}
// TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
// be consistent with the clipping in forward.
auto* bias_grad = ctx.Output<LoDTensor>(framework::GradVarName("Bias"));
if (bias_grad) {
bias_grad->mutable_data<T>(ctx.GetPlace());
zero(dev_ctx, bias_grad, static_cast<T>(0.0));
bit_code->AddGrad(pre_out_grad, bias_grad);
}
if (!is_sparse) {
auto* w_grad = ctx.Output<LoDTensor>(framework::GradVarName("W"));
w_grad->mutable_data<T>(ctx.GetPlace());
zero(dev_ctx, w_grad, static_cast<T>(0.0));
bit_code->MulGradWeight(pre_out_grad, w_grad, in);
} else {
PADDLE_ENFORCE_NOT_NULL(path,
platform::errors::NotFound(
"Custom tree must be set for sparse mode!"));
framework::Vector<int64_t> real_rows = PathToRows(*path);
auto* w_grad = ctx.Output<phi::SelectedRows>(framework::GradVarName("W"));
w_grad->set_rows(real_rows);
// Build a map of id -> row_index to speed up finding the index of one id
w_grad->set_height(w.dims()[0]);
auto* w_grad_value = w_grad->mutable_value();
framework::DDim temp_dim(w.dims());
temp_dim[0] = real_rows.size();
w_grad_value->mutable_data<T>(temp_dim, ctx.GetPlace());
zero(dev_ctx, w_grad_value, static_cast<T>(0.0));
bit_code->MulGradWeight(pre_out_grad, w_grad, in);
}
bit_code->MulGradError(pre_out_grad, w, in_grad);
}
};
} // namespace operators
} // namespace paddle
......@@ -16,7 +16,9 @@ limitations under the License. */
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -28,27 +30,6 @@ class HistogramOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "histogram");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "histogram");
const auto &nbins = ctx->Attrs().Get<int64_t>("bins");
const auto &minval = ctx->Attrs().Get<int>("min");
const auto &maxval = ctx->Attrs().Get<int>("max");
PADDLE_ENFORCE_GE(nbins, 1,
platform::errors::InvalidArgument(
"The bins should be greater than or equal to 1."
"But received nbins is %d",
nbins));
PADDLE_ENFORCE_GE(maxval, minval, platform::errors::InvalidArgument(
"max must be larger or equal to min."
"But received max is %d, min is %d",
maxval, minval));
ctx->SetOutputDim("Out", phi::make_ddim({nbins}));
ctx->ShareLoD("X", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
......@@ -81,7 +62,12 @@ class HistogramOpMaker : public framework::OpProtoAndCheckerMaker {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(histogram, HistogramInferShapeFunctor,
PD_INFER_META(phi::HistogramInferMeta));
REGISTER_OPERATOR(
histogram, ops::HistogramOp, ops::HistogramOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
HistogramInferShapeFunctor);
......@@ -323,6 +323,7 @@ class InplaceABNGradKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(inplace_abn, ops::InplaceABNOp, ops::InplaceABNOpMaker,
ops::BatchNormOpInferVarType,
ops::InplaceABNOpGradMaker<paddle::framework::OpDesc>,
......
......@@ -12,11 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/kthvalue_op.h"
#include <memory>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -25,54 +26,6 @@ class KthvalueOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "kthvalue");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "kthvalue");
OP_INOUT_CHECK(ctx->HasOutput("Indices"), "Output", "Indices", "kthvalue");
auto input_dims = ctx->GetInputDim("X");
const int& dim_size = input_dims.size();
int axis = static_cast<int>(ctx->Attrs().Get<int>("axis"));
PADDLE_ENFORCE_LT(axis, dim_size,
paddle::platform::errors::InvalidArgument(
"the axis must be [-%d, %d), but received %d .",
dim_size, dim_size, axis));
PADDLE_ENFORCE_GE(axis, -dim_size,
paddle::platform::errors::InvalidArgument(
"the axis must be [-%d, %d), but received %d .",
dim_size, dim_size, axis));
if (axis < 0) axis += dim_size;
int k = static_cast<int>(ctx->Attrs().Get<int>("k"));
PADDLE_ENFORCE_GE(
k, 1, paddle::platform::errors::InvalidArgument(
"the k in the kthvalue must >= 1, but received %d .", k));
PADDLE_ENFORCE_GE(input_dims.size(), 1,
paddle::platform::errors::InvalidArgument(
"input of kthvalue must have >= 1d shape"));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_GE(
input_dims[axis], k,
paddle::platform::errors::InvalidArgument(
"input of kthvalue must have >= %d columns in axis of %d", k,
axis));
}
bool keepdim = ctx->Attrs().Get<bool>("keepdim");
std::vector<int64_t> dimvec;
for (int64_t i = 0; i < axis; i++) {
dimvec.emplace_back(input_dims[i]);
}
if (keepdim) {
dimvec.emplace_back(static_cast<int64_t>(1));
}
for (int64_t i = axis + 1; i < dim_size; i++) {
dimvec.emplace_back(input_dims[i]);
}
framework::DDim dims = phi::make_ddim(dimvec);
ctx->SetOutputDim("Out", dims);
ctx->SetOutputDim("Indices", dims);
ctx->ShareLoD("X", "Out");
ctx->ShareLoD("X", "Indices");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -155,20 +108,13 @@ class KthvalueGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(kthvalue, KthvalueInferShapeFunctor,
PD_INFER_META(phi::KthvalueInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(kthvalue, ops::KthvalueOp, ops::KthvalueOpMaker,
ops::KthvalueGradOpMaker<paddle::framework::OpDesc>,
ops::KthvalueGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
kthvalue, ops::KthvalueCPUKernel<paddle::platform::CPUPlace, float>,
ops::KthvalueCPUKernel<paddle::platform::CPUPlace, double>,
ops::KthvalueCPUKernel<paddle::platform::CPUPlace, int32_t>,
ops::KthvalueCPUKernel<paddle::platform::CPUPlace, int64_t>);
ops::KthvalueGradOpMaker<paddle::imperative::OpBase>,
KthvalueInferShapeFunctor);
REGISTER_OPERATOR(kthvalue_grad, ops::KthvalueOpGrad);
REGISTER_OP_CPU_KERNEL(
kthvalue_grad,
ops::KthvalueGradCPUKernel<paddle::platform::CPUPlace, float>,
ops::KthvalueGradCPUKernel<paddle::platform::CPUPlace, double>,
ops::KthvalueGradCPUKernel<paddle::platform::CPUPlace, int32_t>,
ops::KthvalueGradCPUKernel<paddle::platform::CPUPlace, int64_t>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/kthvalue_op.h"
#include "paddle/fluid/operators/top_k_function_cuda.h"
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
#endif
namespace paddle {
namespace operators {
int getBlockSize(int col) {
if (col > 512)
return 1024;
else if (col > 256 && col <= 512)
return 512;
else if (col > 128 && col <= 256)
return 256;
else if (col > 64 && col <= 128)
return 128;
else
return 64;
}
template <typename T>
bool SortKthvalue(const platform::CUDADeviceContext& ctx,
const framework::Tensor* input_tensor, const int64_t num_cols,
const int64_t num_rows, const int k,
framework::Tensor* out_tensor,
framework::Tensor* indices_tensor) {
auto cu_stream = ctx.stream();
framework::Tensor input_indices;
const std::vector<int64_t> dims = {num_rows, num_cols};
auto dim = phi::make_ddim(dims);
input_indices.Resize(dim);
input_indices.mutable_data<int64_t>(ctx.GetPlace());
size_t temp_storage_bytes = -1;
int block_size = getBlockSize(num_cols);
unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
unsigned int grid_size = num_rows < maxGridDimX
? static_cast<unsigned int>(num_rows)
: maxGridDimX;
InitIndex<int64_t><<<grid_size, block_size, 0, cu_stream>>>(
input_indices.data<int64_t>(), num_rows, num_cols);
cub::CountingInputIterator<int64_t> counting_iter(0);
cub::TransformInputIterator<int64_t, SegmentOffsetIter,
cub::CountingInputIterator<int64_t>>
segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
T* sorted_values_ptr;
int64_t* sorted_indices_ptr;
framework::Tensor temp_values, temp_indices;
const T* input = input_tensor->data<T>();
T* values = out_tensor->data<T>();
int64_t* indices = indices_tensor->mutable_data<int64_t>(ctx.GetPlace());
temp_values.Resize(dim);
temp_indices.Resize(dim);
sorted_values_ptr = temp_values.mutable_data<T>(ctx.GetPlace());
sorted_indices_ptr = temp_indices.mutable_data<int64_t>(ctx.GetPlace());
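// First pass: with a null workspace pointer, SortPairs only reports the
// required temp_storage_bytes and performs no sorting (CUB's two-pass usage).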
auto err = cub::DeviceSegmentedRadixSort::SortPairs(
nullptr, temp_storage_bytes, input, sorted_values_ptr,
input_indices.data<int64_t>(), sorted_indices_ptr, num_cols * num_rows,
num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
cu_stream);
#ifdef __HIPCC__
if (err != hipSuccess) {
LOG(ERROR) << "KthvalueOP failed as could not launch "
"hipcub::DeviceSegmentedRadixSort::SortPairs, status: "
<< hipGetErrorString(err);
return false;
}
#else
if (err != cudaSuccess) {
LOG(ERROR) << "KthvalueOP failed as could not launch "
"cub::DeviceSegmentedRadixSort::SortPairs, status: "
<< cudaGetErrorString(err);
return false;
}
#endif
framework::Tensor temp_storage;
temp_storage.mutable_data<uint8_t>(ctx.GetPlace(), temp_storage_bytes);
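// Second pass: run the actual segmented ascending sort with the allocated workspace.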
err = cub::DeviceSegmentedRadixSort::SortPairs(
temp_storage.data<uint8_t>(), temp_storage_bytes, input,
sorted_values_ptr, input_indices.data<int64_t>(), sorted_indices_ptr,
num_cols * num_rows, num_rows, segment_offsets_t, segment_offsets_t + 1,
0, sizeof(T) * 8, cu_stream);
#ifdef __HIPCC__
if (err != hipSuccess) {
LOG(ERROR) << "KthvalueOP failed as could not launch "
"hipcub::DeviceSegmentedRadixSort::SortPairs, "
<< temp_storage_bytes << ", status: " << hipGetErrorString(err);
return false;
}
#else
if (err != cudaSuccess) {
LOG(ERROR) << "KthvalueOP failed as could not launch "
"cub::DeviceSegmentedRadixSort::SortPairs, "
<< temp_storage_bytes << ", status: " << cudaGetErrorString(err);
return false;
}
#endif
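// After the ascending sort, the k-th smallest element of each row sits in
// column k - 1; slice that column and its indices into the outputs.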
auto& dev = *ctx.eigen_device();
const Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{0, k - 1};
const Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{num_rows, 1};
auto e_indices = framework::EigenMatrix<int64_t>::From(*indices_tensor, dim);
auto e_tmp_indices = framework::EigenMatrix<int64_t>::From(
static_cast<const framework::Tensor>(temp_indices));
std::vector<int> odims = {static_cast<int>(num_rows), static_cast<int>(1)};
dim = phi::make_ddim(odims);
auto e_values = framework::EigenMatrix<T>::From(*out_tensor, dim);
auto e_tmp_values = framework::EigenMatrix<T>::From(
static_cast<const framework::Tensor>(temp_values));
EigenSlice<std::decay_t<decltype(dev)>, int64_t, 2>::Eval(
dev, e_indices, e_tmp_indices, slice_indices, slice_sizes);
EigenSlice<std::decay_t<decltype(dev)>, T, 2>::Eval(
dev, e_values, e_tmp_values, slice_indices, slice_sizes);
return true;
}
template <typename DeviceContext, typename T>
class KthvalueOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::InvalidArgument(
"It must use CUDAPlace, you must check your device set."));
auto* input = ctx.Input<framework::Tensor>("X");
auto* output = ctx.Output<framework::Tensor>("Out");
auto* indices = ctx.Output<framework::Tensor>("Indices");
int k = static_cast<int>(ctx.Attr<int>("k"));
int axis = static_cast<int>(ctx.Attr<int>("axis"));
bool keepdim = static_cast<bool>(ctx.Attr<bool>("keepdim"));
const auto& in_dims = input->dims();
if (axis < 0) axis += in_dims.size();
auto out_dims = output->dims();
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
if (axis == in_dims.size() - 1) {
const int64_t& input_height =
phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t& input_width = in_dims[in_dims.size() - 1];
const auto& dev_ctx = ctx.cuda_device_context();
PADDLE_ENFORCE_EQ(SortKthvalue<T>(dev_ctx, input, input_width,
input_height, k, output, indices),
true, platform::errors::External(
"KthvalueOP: Error when use cub sorting"));
return;
} else {
std::vector<int> trans;
for (int i = 0; i < axis; i++) {
trans.emplace_back(i);
}
trans.emplace_back(in_dims.size() - 1);
for (int i = axis + 1; i < in_dims.size() - 1; i++) {
trans.emplace_back(i);
}
trans.emplace_back(axis);
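// trans swaps the target axis with the last dimension so the row-wise
// SortKthvalue path can be reused; results are permuted back below.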
if (!keepdim) {
std::vector<int> tmp_out_shape;
for (int i = 0; i < axis; i++) {
tmp_out_shape.emplace_back(in_dims[i]);
}
tmp_out_shape.emplace_back(1);
for (int i = axis + 1; i < in_dims.size(); i++) {
tmp_out_shape.emplace_back(in_dims[i]);
}
framework::DDim tmp_out_dims = phi::make_ddim(tmp_out_shape);
output->Resize(tmp_out_dims);
indices->Resize(tmp_out_dims);
}
framework::DDim trans_dims(in_dims);
framework::DDim trans_out_dims(in_dims);
for (int i = 0; i < trans.size(); i++) {
trans_dims[i] = in_dims[trans[i]];
trans_out_dims[i] = in_dims[trans[i]];
}
trans_out_dims[in_dims.size() - 1] = 1;
framework::Tensor trans_input;
trans_input.mutable_data<T>(trans_dims, ctx.GetPlace());
int ndims = trans.size();
const auto& dev_ctx = ctx.cuda_device_context();
TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, *input,
&trans_input, trans);
framework::Tensor trans_ind, trans_out;
trans_ind.mutable_data<int64_t>(trans_out_dims, ctx.GetPlace());
trans_out.mutable_data<T>(trans_out_dims, ctx.GetPlace());
const int64_t input_height =
phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
const int64_t input_width = trans_dims[trans_dims.size() - 1];
PADDLE_ENFORCE_EQ(
SortKthvalue<T>(dev_ctx, &trans_input, input_width, input_height, k,
&trans_out, &trans_ind),
true,
platform::errors::External("KthvalueOP: Error when use cub sorting"));
TransCompute<platform::CUDADeviceContext, int64_t>(
ndims, dev_ctx, trans_ind, indices, trans);
TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, trans_out,
output, trans);
if (!keepdim) {
output->Resize(out_dims);
indices->Resize(out_dims);
}
}
}
};
template <typename DeviceContext, typename T>
class KthvalueOpGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(context.GetPlace()), true,
platform::errors::InvalidArgument(
"It must use CUDAPlace, you must check your device set."));
auto* x = context.Input<framework::Tensor>("X");
auto* out_grad =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* indices = context.Input<framework::Tensor>("Indices");
auto* x_grad =
context.Output<framework::Tensor>(framework::GradVarName("X"));
int axis = context.Attr<int>("axis");
int k = static_cast<int>(context.Attr<int>("k"));
const auto& in_dims = x->dims();
auto out_dims = indices->dims();
if (axis < 0) axis += in_dims.size();
T* x_grad_data = x_grad->mutable_data<T>(context.GetPlace());
const T* out_grad_data = out_grad->data<T>();
const int64_t* indices_data = indices->data<int64_t>();
int pre, n, post;
GetDims(in_dims, axis, &pre, &n, &post);
auto& dev_ctx = context.cuda_device_context();
int block_size = getBlockSize(post * k);
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(((max_threads - 1) / block_size + 1), 1);
int grid_size = std::min(max_blocks, pre);
AssignGradWithAxis<T><<<grid_size, block_size, 64 * 4, dev_ctx.stream()>>>(
out_grad_data, indices_data, x_grad_data, pre, post, n, 1);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
kthvalue,
ops::KthvalueOpCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::KthvalueOpCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::KthvalueOpCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::KthvalueOpCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
kthvalue_grad,
ops::KthvalueOpGradCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::KthvalueOpGradCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::KthvalueOpGradCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::KthvalueOpGradCUDAKernel<paddle::platform::CUDADeviceContext,
int64_t>);
......@@ -12,10 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/log_softmax_op.h"
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -24,10 +27,6 @@ class LogSoftmaxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
return UnaryOpUnchangedInferShapeCheckAxis(ctx);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -123,18 +122,11 @@ class LogSoftmaxGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(log_softmax, LogSoftmaxInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMetaCheckAxis));
REGISTER_OPERATOR(log_softmax, ops::LogSoftmaxOp, ops::LogSoftmaxOpMaker,
ops::LogSoftmaxOpInferVarType,
ops::LogSoftmaxGradOpMaker<paddle::framework::OpDesc>,
ops::LogSoftmaxGradOpMaker<paddle::imperative::OpBase>);
ops::LogSoftmaxGradOpMaker<paddle::imperative::OpBase>,
LogSoftmaxInferShapeFunctor);
REGISTER_OPERATOR(log_softmax_grad, ops::LogSoftmaxGradOp);
REGISTER_OP_CPU_KERNEL(
log_softmax,
ops::LogSoftmaxKernel<paddle::platform::CPUDeviceContext, float>,
ops::LogSoftmaxKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
log_softmax_grad,
ops::LogSoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::LogSoftmaxGradKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -12,8 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/log_softmax_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
namespace paddle {
namespace operators {
......@@ -27,7 +28,7 @@ class LogSoftmaxNPUKernel : public framework::OpKernel<T> {
auto* X = ctx.Input<framework::Tensor>("X");
auto* Out = ctx.Output<framework::Tensor>("Out");
const int rank = X->dims().size();
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
const int axis = phi::funcs::CanonicalAxis(ctx.Attr<int>("axis"), rank);
Out->mutable_data<T>(ctx.GetPlace());
if (X->numel() != 0) {
......@@ -47,7 +48,7 @@ class LogSoftmaxGradNPUKernel : public framework::OpKernel<T> {
auto* dOut = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dX = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
const int rank = dOut->dims().size();
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
const int axis = phi::funcs::CanonicalAxis(ctx.Attr<int>("axis"), rank);
// allocate memory on device.
dX->mutable_data<T>(ctx.GetPlace());
......