Unverified commit 336160cf, authored by wanghuancoder, committed by GitHub

force sync batch norm grad sequential (#52268)

* force sync batch norm grad sequential
Parent 551ff882
@@ -26,3 +26,40 @@ paddle::Tensor conv2d_ad_func(const paddle::Tensor& input,
std::vector<int> dilations,
int groups,
std::string data_format);
std::tuple<paddle::Tensor,
paddle::Tensor&,
paddle::Tensor&,
paddle::Tensor,
paddle::Tensor,
paddle::Tensor>
sync_batch_norm__ad_func(const paddle::Tensor& x,
paddle::Tensor& mean, // NOLINT
paddle::Tensor& variance, // NOLINT
const paddle::Tensor& scale,
const paddle::Tensor& bias,
bool is_test,
float momentum,
float epsilon,
std::string data_layout,
bool use_global_stats,
bool trainable_statistics);
namespace sparse {
std::tuple<paddle::Tensor,
paddle::Tensor&,
paddle::Tensor&,
paddle::Tensor,
paddle::Tensor,
paddle::Tensor>
sync_batch_norm__ad_func(const paddle::Tensor& x,
paddle::Tensor& mean, // NOLINT
paddle::Tensor& variance, // NOLINT
const paddle::Tensor& scale,
const paddle::Tensor& bias,
bool is_test,
float momentum,
float epsilon,
std::string data_layout,
bool use_global_stats,
bool trainable_statistics);
} // namespace sparse
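For orientation, here is a minimal caller sketch of the inplace forward API declared above. The include path, the surrounding function, and the attribute values are assumptions made for illustration; only the sync_batch_norm__ad_func signature itself comes from this commit. Note that mean and variance are taken by non-const reference because the op updates the running statistics in place, which is also why the name carries the trailing underscore of Paddle's inplace ops.

// Hypothetical usage sketch -- not part of this commit.
#include <tuple>
#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"  // assumed header

paddle::Tensor RunSyncBatchNorm(const paddle::Tensor& x,
                                paddle::Tensor& running_mean,      // updated in place
                                paddle::Tensor& running_variance,  // updated in place
                                const paddle::Tensor& scale,
                                const paddle::Tensor& bias) {
  auto outs = sync_batch_norm__ad_func(x,
                                       running_mean,
                                       running_variance,
                                       scale,
                                       bias,
                                       /*is_test=*/false,
                                       /*momentum=*/0.9f,
                                       /*epsilon=*/1e-5f,
                                       /*data_layout=*/"NCHW",
                                       /*use_global_stats=*/false,
                                       /*trainable_statistics=*/false);
  // Tuple slot 0 is the normalized output; slots 1-2 alias the running stats.
  return std::get<0>(outs);
}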
set(eager_manual_functions
${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/manual/eager_manual/forwards/add_n_fwd_func.cc
${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/manual/eager_manual/forwards/conv2d_fwd_function.cc
${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/manual/eager_manual/forwards/sync_batch_norm_fwd_func.cc
PARENT_SCOPE)
set(eager_manual_nodes
${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/manual/eager_manual/nodes/conv2d_nodes.cc
${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/manual/eager_manual/nodes/add_n_node.cc
${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/manual/eager_manual/nodes/sync_batch_norm_node.cc
PARENT_SCOPE)
@@ -204,3 +204,174 @@ class AddNGradNodeFinal : public egr::GradNodeBase {
// Attributes
};
class SyncBatchNormGradNode : public egr::GradNodeBase {
public:
SyncBatchNormGradNode() : egr::GradNodeBase() {}
SyncBatchNormGradNode(size_t bwd_in_slot_num, size_t bwd_out_slot_num)
: egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {}
~SyncBatchNormGradNode() override = default;
virtual paddle::small_vector<std::vector<paddle::Tensor>,
egr::kSlotSmallVectorSize>
operator()(paddle::small_vector<std::vector<paddle::Tensor>,
egr::kSlotSmallVectorSize>& grads, // NOLINT
bool create_graph = false,
bool is_new_grad = false) override;
std::string name() override { return "SyncBatchNormGradNode"; }
void ClearTensorWrappers() override {
x_.clear();
scale_.clear();
bias_.clear();
saved_mean_.clear();
saved_variance_.clear();
reserve_space_.clear();
SetIsTensorWrappersCleared(true);
}
std::shared_ptr<GradNodeBase> Copy() const override {
auto copied_node = std::shared_ptr<SyncBatchNormGradNode>(
new SyncBatchNormGradNode(*this));
return copied_node;
}
// SetTensorWrapperX, SetTensorWrapperY, ...
void SetTensorWrapperx(const paddle::Tensor& x) {
x_ = egr::TensorWrapper(x, false);
}
void SetTensorWrapperscale(const paddle::Tensor& scale) {
scale_ = egr::TensorWrapper(scale, false);
}
void SetTensorWrapperbias(const paddle::Tensor& bias) {
bias_ = egr::TensorWrapper(bias, false);
}
void SetTensorWrappersaved_mean(const paddle::Tensor& saved_mean) {
saved_mean_ = egr::TensorWrapper(saved_mean, false);
}
void SetTensorWrappersaved_variance(const paddle::Tensor& saved_variance) {
saved_variance_ = egr::TensorWrapper(saved_variance, false);
}
void SetTensorWrapperreserve_space(const paddle::Tensor& reserve_space) {
reserve_space_ = egr::TensorWrapper(reserve_space, false);
}
// SetAttributes
void SetAttributemomentum(const float& momentum) { momentum_ = momentum; }
void SetAttributeepsilon(const float& epsilon) { epsilon_ = epsilon; }
void SetAttributedata_layout(const std::string& data_layout) {
data_layout_ = data_layout;
}
void SetAttributeis_test(const bool& is_test) { is_test_ = is_test; }
void SetAttributeuse_global_stats(const bool& use_global_stats) {
use_global_stats_ = use_global_stats;
}
void SetAttributetrainable_statistics(const bool& trainable_statistics) {
trainable_statistics_ = trainable_statistics;
}
private:
// TensorWrappers
egr::TensorWrapper x_;
egr::TensorWrapper scale_;
egr::TensorWrapper bias_;
egr::TensorWrapper saved_mean_;
egr::TensorWrapper saved_variance_;
egr::TensorWrapper reserve_space_;
// Attributes
float momentum_;
float epsilon_;
std::string data_layout_;
bool is_test_;
bool use_global_stats_;
bool trainable_statistics_;
};
namespace sparse {
class SyncBatchNormGradNode : public egr::GradNodeBase {
public:
SyncBatchNormGradNode() : egr::GradNodeBase() {}
SyncBatchNormGradNode(size_t bwd_in_slot_num, size_t bwd_out_slot_num)
: egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {}
~SyncBatchNormGradNode() override = default;
virtual paddle::small_vector<std::vector<paddle::Tensor>,
egr::kSlotSmallVectorSize>
operator()(paddle::small_vector<std::vector<paddle::Tensor>,
egr::kSlotSmallVectorSize>& grads, // NOLINT
bool create_graph = false,
bool is_new_grad = false) override;
std::string name() override { return "SyncBatchNormGradNode"; }
void ClearTensorWrappers() override {
x_.clear();
scale_.clear();
bias_.clear();
saved_mean_.clear();
saved_variance_.clear();
reserve_space_.clear();
SetIsTensorWrappersCleared(true);
}
std::shared_ptr<GradNodeBase> Copy() const override {
auto copied_node = std::shared_ptr<SyncBatchNormGradNode>(
new SyncBatchNormGradNode(*this));
return copied_node;
}
// SetTensorWrapperX, SetTensorWrapperY, ...
void SetTensorWrapperx(const paddle::Tensor& x) {
x_ = egr::TensorWrapper(x, false);
}
void SetTensorWrapperscale(const paddle::Tensor& scale) {
scale_ = egr::TensorWrapper(scale, false);
}
void SetTensorWrapperbias(const paddle::Tensor& bias) {
bias_ = egr::TensorWrapper(bias, false);
}
void SetTensorWrappersaved_mean(const paddle::Tensor& saved_mean) {
saved_mean_ = egr::TensorWrapper(saved_mean, false);
}
void SetTensorWrappersaved_variance(const paddle::Tensor& saved_variance) {
saved_variance_ = egr::TensorWrapper(saved_variance, false);
}
void SetTensorWrapperreserve_space(const paddle::Tensor& reserve_space) {
reserve_space_ = egr::TensorWrapper(reserve_space, false);
}
// SetAttributes
void SetAttributemomentum(const float& momentum) { momentum_ = momentum; }
void SetAttributeepsilon(const float& epsilon) { epsilon_ = epsilon; }
void SetAttributedata_layout(const std::string& data_layout) {
data_layout_ = data_layout;
}
void SetAttributeis_test(const bool& is_test) { is_test_ = is_test; }
void SetAttributeuse_global_stats(const bool& use_global_stats) {
use_global_stats_ = use_global_stats;
}
void SetAttributetrainable_statistics(const bool& trainable_statistics) {
trainable_statistics_ = trainable_statistics;
}
private:
// TensorWrappers
egr::TensorWrapper x_;
egr::TensorWrapper scale_;
egr::TensorWrapper bias_;
egr::TensorWrapper saved_mean_;
egr::TensorWrapper saved_variance_;
egr::TensorWrapper reserve_space_;
// Attributes
float momentum_;
float epsilon_;
std::string data_layout_;
bool is_test_;
bool use_global_stats_;
bool trainable_statistics_;
};
} // namespace sparse
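To show how the pieces above are meant to fit together, here is a hedged sketch of the wiring the manual forward (sync_batch_norm_fwd_func.cc, added to the build in this commit) would typically perform. The slot counts and local variable names are illustrative placeholders; the only APIs assumed are the setters declared above and the Controller method added further down in this commit.

// Illustrative wiring inside the manual forward -- names and slot counts are placeholders.
auto grad_node = std::shared_ptr<SyncBatchNormGradNode>(
    new SyncBatchNormGradNode(/*bwd_in_slot_num=*/1, /*bwd_out_slot_num=*/3));

// Capture the tensors the backward kernel will need.
grad_node->SetTensorWrapperx(x);
grad_node->SetTensorWrapperscale(scale);
grad_node->SetTensorWrapperbias(bias);
grad_node->SetTensorWrappersaved_mean(saved_mean);
grad_node->SetTensorWrappersaved_variance(saved_variance);
grad_node->SetTensorWrapperreserve_space(reserve_space);

// Record the attributes the backward kernel will need.
grad_node->SetAttributemomentum(momentum);
grad_node->SetAttributeepsilon(epsilon);
grad_node->SetAttributedata_layout(data_layout);
grad_node->SetAttributeis_test(is_test);
grad_node->SetAttributeuse_global_stats(use_global_stats);
grad_node->SetAttributetrainable_statistics(trainable_statistics);

// Register the node so RunBackward replays all sync_batch_norm_ grad nodes
// in a fixed order (see the Controller and backward changes below).
egr::Controller::Instance().PushBackForceSequentialNodes(grad_node.get());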
@@ -40,6 +40,8 @@ class UniqueNameGenerator {
// TODO(jiabin): Now we are using imperative tracer, move it here when we
// deprecate imperative.
class GradNodeBase;
class Controller {
public:
static Controller& Instance() { return *controller_; }
@@ -119,6 +121,18 @@ class Controller {
void ClearFinalBackwardHooks() { final_backward_hooks_.clear(); }
void ClearForceSequentialNodes() {
while (!force_sequential_nodes_.empty()) {
force_sequential_nodes_.pop();
}
}
void PushBackForceSequentialNodes(GradNodeBase* node) {
force_sequential_nodes_.push(node);
}
std::queue<GradNodeBase*> GetForceSequentialNodes() {
return force_sequential_nodes_;
}
private:
Controller() = default;
static Controller* controller_;
@@ -132,6 +146,7 @@ class Controller {
std::vector<std::vector<std::unordered_map<int, int>>>>
custom_edges_slot_map_;
std::vector<std::shared_ptr<VoidHook>> final_backward_hooks_;
std::queue<GradNodeBase*> force_sequential_nodes_;
DISABLE_COPY_AND_ASSIGN(Controller);
};
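A minimal sketch of the intended life cycle of this queue, assuming only the three methods added above; the two helper functions below are made-up names for illustration. The forward pass records every sync_batch_norm_ grad node in creation order, and the backward pass takes a copy of the queue (GetForceSequentialNodes returns it by value) and resets the global state before scheduling.

// Illustrative only -- ForwardSide/BackwardSide are not real Paddle functions.
#include <queue>

void ForwardSide(egr::GradNodeBase* sync_bn_grad_node) {
  // Called once per sync_batch_norm_ forward, so every rank records the same order.
  egr::Controller::Instance().PushBackForceSequentialNodes(sync_bn_grad_node);
}

void BackwardSide() {
  std::queue<egr::GradNodeBase*> nodes =
      egr::Controller::Instance().GetForceSequentialNodes();
  egr::Controller::Instance().ClearForceSequentialNodes();
  // ... drain `nodes` front-to-back while scheduling grad nodes (see backward.cc below) ...
}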
@@ -57,6 +57,7 @@ black_ops_list = [
"conv2d_grad_grad",
"add_n",
"add_n_grad",
"sync_batch_norm_",
]
@@ -111,6 +111,22 @@ std::vector<paddle::Tensor> RunBackward(
const std::vector<paddle::Tensor>& no_grad_vars = {}) {
VLOG(3) << "Start Backward";
std::queue<GradNodeBase*> force_sequential_nodes_forward_queue =
egr::Controller::Instance().GetForceSequentialNodes();
egr::Controller::Instance().ClearForceSequentialNodes();
std::deque<GradNodeBase*> force_sequential_nodes_queue;
std::set<GradNodeBase*> force_sequential_nodes_set;
std::set<GradNodeBase*> ready_force_sequential_nodes;
auto force_sequential_nodes_size =
force_sequential_nodes_forward_queue.size();
for (size_t i = 0; i < force_sequential_nodes_size; ++i) {
force_sequential_nodes_set.insert(
force_sequential_nodes_forward_queue.front());
force_sequential_nodes_queue.push_front(
force_sequential_nodes_forward_queue.front());
force_sequential_nodes_forward_queue.pop();
}
// *Gradient Hook should happen at node-level
// *Inplace version check should perform at node-level
// *Cross-batch accumulation happens at forward pass
@@ -355,12 +371,34 @@ std::vector<paddle::Tensor> RunBackward(
"Node's in-degree cannot be negative.",
next_node->name()));
auto add_next_node_func = [&node_in_degree_map,
&queue](GradNodeBase* next_node) {
if (node_in_degree_map[next_node] == 0) {
if (dynamic_cast<egr::GradNodeAccumulation*>(next_node)) {
queue.push_front(std::move(next_node));
} else {
queue.push_back(std::move(next_node));
}
}
};
if (force_sequential_nodes_set.count(next_node)) {
if (force_sequential_nodes_queue.front() == next_node) {
force_sequential_nodes_queue.pop_front();
add_next_node_func(next_node);
while (ready_force_sequential_nodes.count(
force_sequential_nodes_queue.front())) {
ready_force_sequential_nodes.erase(
force_sequential_nodes_queue.front());
add_next_node_func(force_sequential_nodes_queue.front());
force_sequential_nodes_queue.pop_front();
}
} else {
ready_force_sequential_nodes.insert(next_node);
continue;
}
} else {
add_next_node_func(next_node);
}
}
}
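In the block above, a node registered as force-sequential is only pushed into the ready queue when it sits at the front of force_sequential_nodes_queue; because that deque was filled with push_front from the forward-order FIFO, its front is the node created last in forward, which is exactly the one expected to run first in backward. A node whose gradient arrives too early is parked in ready_force_sequential_nodes and released as soon as its turn comes, so every rank launches the sync_batch_norm_ grad nodes (and their collective communication) in the same sequence. A standalone toy model of that release policy, with plain ints standing in for grad nodes (purely illustrative, not Paddle code):

// Toy re-implementation of the ordering rule above; builds with any C++11 compiler.
#include <deque>
#include <iostream>
#include <set>
#include <vector>

int main() {
  // Required execution order, front = the node that must run next
  // (mirrors force_sequential_nodes_queue).
  std::deque<int> required = {3, 2, 1};
  std::set<int> forced(required.begin(), required.end());
  std::set<int> parked;  // mirrors ready_force_sequential_nodes
  std::vector<int> executed;

  // Nodes happen to become ready out of order.
  std::vector<int> ready_order = {2, 1, 3};
  for (int ready : ready_order) {
    if (!forced.count(ready)) {  // unconstrained nodes run immediately
      executed.push_back(ready);
      continue;
    }
    if (!required.empty() && required.front() == ready) {
      required.pop_front();
      executed.push_back(ready);
      // Release any parked nodes that are now next in line.
      while (!required.empty() && parked.count(required.front())) {
        parked.erase(required.front());
        executed.push_back(required.front());
        required.pop_front();
      }
    } else {
      parked.insert(ready);  // not its turn yet; hold it back
    }
  }

  for (int n : executed) std::cout << n << ' ';  // prints: 3 2 1
  std::cout << '\n';
  return 0;
}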