Unverified · Commit d1e93be1, authored by Jiabin Yang, committed by GitHub

[Eager] Optimize Grad by pruning useless branches (#47827)

* [Eager] Fix paddle.grad interface

* [Eager] Support minimum SubGraph for GeneralGrad

* Add needed_nodes to prune grad graph more thoroughly

* [Eager] Add grad_node_trans_mapping_ to record which grad_node has been transformed to AccumulationNode

* [Eager] Fix paddle.grad interface

* Polish code

* remove potential_stop_node

* Add endding_nodes to enhance the subgraph-generation logic

* clear endding_nodes_

* polish code

* rename endding_nodes to endding_nodes_

* Refactor grad interface

* Add register_hook case to fix coverage-ci

* Fix code format

* Refactor general_grad

* Add more code comments

* call clear directly to release GradSlotMeta

* fix a mistake

* fix matmul/multiply kernel logic, optional inputs in yaml, fill-zeros logic, and so on

* fix batch_norm_double_grad yaml optional config

* fix tanh_triple_grad yaml and kernels

* fix MultiplyTripleGradKernel optional logic

* fix merge mistake

* fix compile error

* remove legacy attr for bn

* polish code

* fix some kernels

* merge develop

* fix error

* remove log

* fix kernels with full_like

* hide value log behind

* hide value log behind

* fix matmul_triple grad
Co-authored-by: Weilong Wu <veyron_wu@163.com>
Parent: 0754e09d
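The user-visible contract here is `paddle.grad`: it only has to execute the subgraph between `inputs` and `outputs`, and this commit prunes everything outside that subgraph more aggressively. A minimal sketch of the behavior (assumes PaddlePaddle 2.4+ eager mode; the tensors and branch names are illustrative, not from the diff):

import paddle

x = paddle.randn([4, 4])
x.stop_gradient = False

z = paddle.matmul(x, x)
w = paddle.tanh(z)   # the branch we request gradients through
u = (z * z).sum()    # unrelated branch; its grad nodes can now be pruned

# Only the w -> x subgraph runs; edges into u's grad nodes are marked
# stop-gradient and cleared while copying the graph (see general_grad.h below).
(dx,) = paddle.grad(outputs=[w.sum()], inputs=[x])
print(dx.shape)  # [4, 4]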
@@ -35,6 +35,7 @@ ops_to_fill_zero_for_empty_grads = set(
     "multiply_triple_grad",
     "conv2d_grad_grad",
     "batch_norm_double_grad",
+    "tanh_grad",
     "tanh_double_grad",
     "tanh_triple_grad",
     "sin_double_grad",
......
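`tanh_grad` joins the set of ops whose empty incoming grads are zero-filled by the generated code; with the new pruning, some of its grad slots can legitimately arrive uninitialized. A quick way to exercise the higher-order tanh path (a usage sketch, not part of the diff):

import paddle

x = paddle.to_tensor([0.2, -0.5], stop_gradient=False)
y = paddle.tanh(x)
(g1,) = paddle.grad(y.sum(), x, create_graph=True)   # 1 - tanh(x)^2
(g2,) = paddle.grad(g1.sum(), x, create_graph=True)  # runs tanh_double_grad
(g3,) = paddle.grad(g2.sum(), x)                     # runs tanh_triple_grad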
@@ -230,7 +230,7 @@ FORWARD_FUNCTION_TEMPLATE = """
 AFTER_LOG_PRINT_TEMPLATE = """
   if(VLOG_IS_ON(4)){{
-    const char* INPUT_PRINT_TEMPLATE = \"{{ Input: [%s], Output: [%s] }} \";
+    const char* INPUT_PRINT_TEMPLATE = \"{{ Input: [%s], \\n Output: [%s] }} \";
     {}
     VLOG(4) << paddle::string::Sprintf(INPUT_PRINT_TEMPLATE, input_str, output_str);
   }}
......
@@ -173,9 +173,10 @@ std::vector<paddle::experimental::Tensor> RunBackward(
       node_input_buffers_dict[grad_node] =
           std::make_unique<GradTensorHolder>(grad_node->InputMeta());
     }
-    bool copy_from_grad_t =
-        grad_tensors.size() > 0 && grad_tensors[i].initialized();
-    if (copy_from_grad_t) {
+    // copy grad tensor since we should totally run grad without affect forward
+    // value
+    if (grad_tensors.size() > 0 && grad_tensors[i].initialized()) {
       PADDLE_ENFORCE(
           grad_tensors.size() == tensors.size(),
           paddle::platform::errors::Fatal(
@@ -357,16 +358,6 @@ std::vector<paddle::experimental::Tensor> RunBackward(
                             "Node's in-degree cannot be negative.",
                             next_node->name()));
-      if (is_general_grad) {
-        if (node_in_degree_map[next_node] == 0 &&
-            GeneralGrad::Instance().IsNeededNodes(next_node)) {
-          if (dynamic_cast<egr::GradNodeAccumulation*>(next_node)) {
-            queue.push_front(std::move(next_node));
-          } else {
-            queue.push_back(std::move(next_node));
-          }
-        }
-      } else {
       if (node_in_degree_map[next_node] == 0) {
         if (dynamic_cast<egr::GradNodeAccumulation*>(next_node)) {
           queue.push_front(std::move(next_node));
@@ -377,7 +368,6 @@ std::vector<paddle::experimental::Tensor> RunBackward(
         }
       }
     }
-    }
   VLOG(7) << "Run Backward Final hook size: "
           << egr::Controller::Instance().FinalBackwardHooks().size();
......
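With pruning moved into the graph-copy stage, `RunBackward` keeps a single scheduling loop for both the full-backward and `paddle.grad` cases. Roughly, in Python terms (a simplified sketch; the node, queue, and map types are hypothetical stand-ins for the C++ structures):

from collections import deque

def run_backward(ready_nodes, in_degree, next_nodes_of, is_accumulation):
    queue = deque(ready_nodes)
    while queue:
        node = queue.popleft()
        # ... run node's grad kernel, then release its successors ...
        for nxt in next_nodes_of(node):
            in_degree[nxt] -= 1
            assert in_degree[nxt] >= 0, "Node's in-degree cannot be negative."
            if in_degree[nxt] == 0:
                # Accumulation nodes go to the front so leaf grads finish early.
                if is_accumulation(nxt):
                    queue.appendleft(nxt)
                else:
                    queue.append(nxt)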
@@ -51,6 +51,10 @@ class GeneralGrad {
     for (size_t i = 0; i < num_inputs; i++) {
       AutogradMeta* auto_grad_meta =
           EagerUtils::unsafe_autograd_meta(inputs[i]);
+      PADDLE_ENFORCE_NOT_NULL(
+          auto_grad_meta,
+          paddle::platform::errors::Fatal(
+              "We got %s:[%d] 's autograd meta is NULL.", msg, i));
       auto* target_node = auto_grad_meta->GetMutableGradNode().get();

       if (orig_to_copied_node_map_.count(target_node)) {
@@ -82,10 +86,13 @@ class GeneralGrad {
   // input_target_nodes
   void PurifyPotentialStartUpNodes() {
     VLOG(6) << "Running in PurifyPotentialStartUpNodes";
-    if (input_target_nodes_inputmeta_map_.empty()) return;
+    if (input_target_nodes_inputmeta_map_.empty()) {
+      VLOG(6) << "No input target nodes found, skip.";
+      return;
+    }
     std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
-    for (auto startup_op : potential_startup_nodes_) {
-      auto iter = input_target_nodes_inputmeta_map_.find(startup_op);
+    for (auto startup_node : potential_startup_nodes_) {
+      auto iter = input_target_nodes_inputmeta_map_.find(startup_node);
       if (iter != input_target_nodes_inputmeta_map_.end()) {
         potential_startup_nodes_to_be_erased.emplace(iter->first);
       }
@@ -157,11 +164,11 @@ class GeneralGrad {
           potential_startup_nodes_.erase(node);
         }
       }
-    }
+    }  // TODO(jiabin): May we need some check here.
   }

   // Get Graph Info Betweent input target GradNode and outputs,
-  // record depending_nodes_, potential_startup_nodes_
+  // record depending_nodes_
   void GetGraphInfoBetweenTargets(const std::deque<GradNodeBase*>& init_queue) {
     VLOG(6) << "Runing In GetGraphInfoBetweenTargets";
@@ -227,7 +234,7 @@ class GeneralGrad {
             std::make_shared<paddle::experimental::Tensor>(target_result);
       }
     }
-  }
+  }  // TODO(jiabin): Some check here.

   void SetResultForEnddingNodes(
@@ -319,21 +326,22 @@ class GeneralGrad {
   void SetNodeToAccumulationNode(GradNodeBase* node) {
     if (dynamic_cast<egr::GradNodeAccumulation*>(node)) return;
     if (!(depending_nodes_)[node].empty()) {
+      // Find precedding_nodes of current node.
       auto precedding_nodes = (depending_nodes_)[node];
       for (auto pre_nodes : precedding_nodes) {
         paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
             pre_nodes_edges = pre_nodes->MutableOutputMeta();
         for (size_t i = 0; i < pre_nodes_edges.size(); i++) {
           for (size_t j = 0; j < pre_nodes_edges[i].size(); j++) {
-            auto edge_ = pre_nodes_edges[i][j].GetEdge();
+            const auto& edge_ = pre_nodes_edges[i][j].GetEdge();
             if (edge_.GetGradNode() == node) {
-              auto autograd_meta = egr::AutogradMeta(edge_);
               Edge& pre_node_edge = pre_nodes_edges[i][j].GetMutableEdge();

               if (copied_node_to_endding_node_map_.count(node)) {
                 pre_node_edge.SetGradNode(
                     copied_node_to_endding_node_map_[node]);
               } else {
+                auto autograd_meta = egr::AutogradMeta(edge_);
                 std::shared_ptr<GradNodeBase> shared_grad_node_accumulation =
                     std::make_shared<egr::GradNodeAccumulation>(&autograd_meta);
                 pre_node_edge.SetGradNode(shared_grad_node_accumulation);
@@ -361,7 +369,7 @@ class GeneralGrad {
             grad_node->SetGradientHookFuntions(
                 node->GetGradientHookFuntions());
         }
-      }
+      }  // or this node has no need to change
     }
   }
 }
@@ -381,12 +389,10 @@ class GeneralGrad {
       }
       visited.insert(node);

-      if (IsInputTargetNodes(node)) {
-        if (IsEnddingNodes(node)) {
-          SetNodeToAccumulationNode(node);
-          continue;
-        }
-      }
+      if (IsInputTargetNodes(node) && IsEnddingNodes(node)) {
+        SetNodeToAccumulationNode(node);
+        continue;
+      }

       paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
           meta = node->MutableOutputMeta();
@@ -411,7 +417,17 @@ class GeneralGrad {
           continue;
         }

-        // TODO(weilong): support prune logic deeper
+        if (meta.size() != 1 && IsNeededNodes(node) &&
+            !IsNeededNodes(next_node.get()) && !IsEnddingNodes(node)) {
+          VLOG(3) << "Get stop edge from grad_node: " << node->name() << " : "
+                  << node << " to:" << next_node->name() << ", "
+                  << next_node.get() << " with output rank info: " << i
+                  << ", " << j;
+          // No need to compute grad from needed Nodes to no need Nodes
+          meta[i][j].SetStopGradient(true);
+          edge.Clear();
+          continue;
+        }

         // Update BFS queue
         queue_.push_back(next_node.get());
@@ -502,7 +518,8 @@ class GeneralGrad {
     // Save node and update mapping
     orig_to_copied_node_map_[orig_node.get()] = copied_node;
     copied_grad_nodes_.push_back(copied_node);
+    VLOG(3) << "Copied Node: " << orig_node->name() << " ptr: " << orig_node
+            << " to ptr: " << copied_node;

     return copied_node.get();
   }
......
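The new stop-edge logic above is the heart of the pruning: gradient flow from needed nodes into nodes outside the requested subgraph is cut during the graph copy. In pseudo-Python (attribute names are illustrative mirrors of the C++ members):

def prune_stop_edges(node, needed, ending):
    """Clear edges from a needed node to nodes outside the grad subgraph."""
    meta = node.output_meta            # list of slots, each a list of GradSlotMeta
    for slot in meta:
        for m in slot:
            nxt = m.edge.grad_node
            if nxt is None:
                continue
            if (len(meta) != 1 and node in needed
                    and nxt not in needed and node not in ending):
                m.stop_gradient = True  # no grad is computed along this edge
                m.edge.clear()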
@@ -99,6 +99,11 @@ void GradTensorHolder::add(size_t slot_id,
                            size_t rank,
                            const paddle::experimental::Tensor& t,
                            bool create_graph) {
+  if (!t.initialized()) {
+    VLOG(3) << "No need to do accumulate for uninitialized t.";
+    return;
+  }  // TODO(jiabin): Remove this when we fix all kernel.
   PADDLE_ENFORCE(slot_id < buffer_.size(),
                  paddle::platform::errors::Fatal(
                      "Invalid slot_id for GradTensorHolder::add() "
......
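The guard simply drops uninitialized grads instead of accumulating them, so the zero-fill list above does not have to cover every kernel individually. A Python mirror of the control flow (the names `buffer` and `t` are hypothetical stand-ins for the C++ members):

def grad_holder_add(buffer, slot_id, rank, t):
    if not t.initialized():
        return  # uninitialized grads are skipped rather than accumulated
    assert slot_id < len(buffer), "Invalid slot_id for GradTensorHolder::add()"
    # ... existing accumulation path continues here ...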
@@ -277,7 +277,58 @@ class EagerUtils {
     } else {
       tensor_info_str += "Unknown";
     }
-    if (VLOG_IS_ON(6)) {
+    if (VLOG_IS_ON(11)) {
+      const char* TENSOR_PRINT_TEMPLATE =
+          "{Name: %s, Initialized: %d, Ptr: %d "
+          "TensorInfo: [ %s ], Value:[ %s ], ADInfo:[ %s ]}";
+      auto* ad_meta = nullable_autograd_meta(t);
+      if (ad_meta && (ad_meta->WeakGrad().lock().get())) {
+        std::string ad_info_str = "";
+        const char* AD_INFO_TEMPLATE =
+            "Grad: [ %s ], GradNode: [ %s ], StopGradient: [ %d ]";
+        ad_info_str += paddle::string::Sprintf(AD_INFO_TEMPLATE,
+                                               TensorStr(ad_meta->Grad()),
+                                               GradNodeStr(t),
+                                               ad_meta->StopGradient());
+        auto* data_ptr = dynamic_cast<phi::DenseTensor*>(t.impl().get());
+        if (t.is_initialized() && data_ptr) {
+          return paddle::string::Sprintf(TENSOR_PRINT_TEMPLATE,
+                                         tensor_name_str,
+                                         t.initialized(),
+                                         t.impl(),
+                                         tensor_info_str,
+                                         *data_ptr,
+                                         ad_info_str);
+        } else {
+          return paddle::string::Sprintf(TENSOR_PRINT_TEMPLATE,
+                                         tensor_name_str,
+                                         t.initialized(),
+                                         t.impl(),
+                                         tensor_info_str,
+                                         "None",
+                                         ad_info_str);
+        }
+      } else {
+        auto* data_ptr = dynamic_cast<phi::DenseTensor*>(t.impl().get());
+        if (t.is_initialized() && data_ptr) {
+          return paddle::string::Sprintf(TENSOR_PRINT_TEMPLATE,
+                                         tensor_name_str,
+                                         t.initialized(),
+                                         t.impl(),
+                                         tensor_info_str,
+                                         *data_ptr,
+                                         "None");
+        } else {
+          return paddle::string::Sprintf(TENSOR_PRINT_TEMPLATE,
+                                         tensor_name_str,
+                                         t.initialized(),
+                                         t.impl(),
+                                         tensor_info_str,
+                                         "None",
+                                         "None");
+        }
+      }
+    } else if (VLOG_IS_ON(6)) {
       const char* TENSOR_PRINT_TEMPLATE =
           "{Name: %s, Initialized: %d, Ptr: %d "
           "TensorInfo: [ %s ], ADInfo:[ %s ]}";
......
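Tensor values now print only at verbosity 11, keeping the level-6 logs compact (the two "hide value log behind" commits). VLOG levels in Paddle are driven by the `GLOG_v` environment variable, which must be set before the framework is imported (a usage sketch, not part of the diff):

import os
os.environ["GLOG_v"] = "11"  # 6 -> metadata only; 11 -> metadata plus tensor values
import paddle  # import after setting the flag so the native logger picks it up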
@@ -187,6 +187,7 @@
     param : [x, x]
   kernel :
     func : cos_double_grad
+  optional: grad_out
   backward : cos_triple_grad
   inplace : (grad_x_grad -> grad_out_grad)

@@ -211,6 +212,7 @@
     param : [x, x, grad_x_grad_forward]
   kernel :
     func : cos_triple_grad
+  optional: grad_out_forward, grad_x_grad_forward, grad_out_grad_grad
   inplace : (grad_x_grad_forward -> grad_out_forward_grad)

 - backward_op : cosh_grad
@@ -872,6 +874,7 @@
     param : [x, x]
   kernel :
     func : sin_double_grad
+  optional: grad_out
   backward : sin_triple_grad
   inplace : (grad_x_grad -> grad_out_grad)

@@ -896,6 +899,7 @@
     param : [x, x, grad_x_grad_forward]
   kernel :
     func : sin_triple_grad
+  optional: grad_out_forward, grad_x_grad_forward, grad_out_grad_grad
   inplace : (grad_x_grad_forward -> grad_out_forward_grad)

 - backward_op : sinh_grad
@@ -1054,6 +1058,7 @@
   kernel :
     func : tanh_triple_grad
   inplace : (grad_x_grad_forward -> grad_out_forward_grad)
+  optional : grad_out_new_grad, grad_out_grad_grad

 - backward_op : thresholded_relu_grad
   forward : thresholded_relu (Tensor x, float threshold) -> Tensor(out)
......
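Each `optional:` entry turns the corresponding kernel argument into a `paddle::optional<DenseTensor>`, so a pruned branch can simply pass nothing. Chained `paddle.grad` calls reach these kernels (a sketch):

import paddle

x = paddle.to_tensor([0.1, 1.2], stop_gradient=False)
y = paddle.sin(x)
(d1,) = paddle.grad(y.sum(), x, create_graph=True)   # cos(x)
(d2,) = paddle.grad(d1.sum(), x, create_graph=True)  # sin_double_grad: -sin(x)
(d3,) = paddle.grad(d2.sum(), x)                     # sin_triple_grad: -cos(x)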
@@ -124,7 +124,7 @@
   kernel :
     func : batch_norm_grad_grad
     data_type : x
-  optional : out_mean, out_variance
+  optional : out_mean, out_variance, grad_x_grad, grad_scale_grad, grad_bias_grad
   inplace : (grad_out -> grad_out_grad)

 - backward_op : batch_norm_grad
@@ -856,7 +856,7 @@
     param : [x, y, fwd_grad_out, fwd_grad_grad_x, fwd_grad_grad_y]
   kernel :
     func : matmul_triple_grad
-  optional : grad_x_grad, grad_y_grad, grad_grad_out_grad
+  optional : fwd_grad_grad_x, fwd_grad_grad_y, grad_x_grad, grad_y_grad, grad_grad_out_grad

 - backward_op : max_grad
   forward: max (Tensor x, IntArray axis={}, bool keepdim=false) -> Tensor(out)
@@ -1024,10 +1024,10 @@
   output : Tensor(x_grad), Tensor(y_grad), Tensor(fwd_grad_out_grad), Tensor(fwd_grad_grad_x_grad), Tensor(fwd_grad_grad_y_grad)
   infer_meta :
     func : GeneralQuinaryGradInferMeta
-    param : [x, y, fwd_grad_out, x, y]
+    param : [x, y, fwd_grad_out, fwd_grad_grad_x, fwd_grad_grad_y]
   kernel :
     func : multiply_triple_grad
-  optional : fwd_grad_grad_x, fwd_grad_grad_y, grad_grad_out_grad
+  optional : fwd_grad_grad_x, fwd_grad_grad_y, grad_x_grad, grad_y_grad, grad_grad_out_grad

 - backward_op : nearest_interp_grad
   forward : nearest_interp (Tensor x, Tensor out_size, Tensor[] size_tensor, Tensor scale_tensor, str data_layout, int out_d, int out_h, int out_w, float[] scale, str interp_method, bool align_corners, int align_mode) -> Tensor(output)
......
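For `multiply_triple_grad` the newly optional `d_dx`/`d_dy` inputs arise whenever one branch of a higher-order grad is never requested. For instance (a sketch; with `allow_unused=True` the pruned input comes back as `None`, which the optional kernel inputs now tolerate):

import paddle

x = paddle.rand([3]); x.stop_gradient = False
y = paddle.rand([3]); y.stop_gradient = False
dx, dy = paddle.grad((x * y).sum(), [x, y], create_graph=True)
# dx equals y, so it does not depend on x; the second-order grad for x
# is absent rather than a zero tensor.
(d2x,) = paddle.grad(dx.sum(), [x], allow_unused=True)
print(d2x)  # None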
@@ -83,7 +83,7 @@ void ReluDoubleGradKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void SinDoubleGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
-                         const DenseTensor& dout,
+                         const paddle::optional<DenseTensor>& dout,
                          const DenseTensor& ddx,
                          DenseTensor* dx,
                          DenseTensor* ddout);
@@ -91,7 +91,7 @@ void SinDoubleGradKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void CosDoubleGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
-                         const DenseTensor& dout,
+                         const paddle::optional<DenseTensor>& dout,
                          const DenseTensor& ddx,
                          DenseTensor* dx,
                          DenseTensor* ddout);
@@ -109,8 +109,8 @@ void TanhTripleGradKernel(const Context& dev_ctx,
                           const DenseTensor& out,
                           const DenseTensor& dout,
                           const DenseTensor& ddx,
-                          const DenseTensor& d_dout_new,
-                          const DenseTensor& d_ddout,
+                          const paddle::optional<DenseTensor>& d_dout_new,
+                          const paddle::optional<DenseTensor>& d_ddout,
                           DenseTensor* d_out_new,
                           DenseTensor* d_dout,
                           DenseTensor* d_ddx);
@@ -118,10 +118,10 @@ void TanhTripleGradKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void SinTripleGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
-                         const DenseTensor& dout,
-                         const DenseTensor& ddx,
+                         const paddle::optional<DenseTensor>& dout,
+                         const paddle::optional<DenseTensor>& ddx,
                          const DenseTensor& d_dx_new,
-                         const DenseTensor& d_ddout,
+                         const paddle::optional<DenseTensor>& d_ddout,
                          DenseTensor* d_x_new,
                          DenseTensor* d_dout,
                          DenseTensor* d_ddx);
@@ -129,10 +129,10 @@ void SinTripleGradKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void CosTripleGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
-                         const DenseTensor& dout,
-                         const DenseTensor& ddx,
+                         const paddle::optional<DenseTensor>& dout,
+                         const paddle::optional<DenseTensor>& ddx,
                          const DenseTensor& d_dx_new,
-                         const DenseTensor& d_ddout,
+                         const paddle::optional<DenseTensor>& d_ddout,
                          DenseTensor* d_x_new,
                          DenseTensor* d_dout,
                          DenseTensor* d_ddx);
......
@@ -64,7 +64,8 @@ void BatchNormGradKernel(const Context& dev_ctx,
                          DenseTensor* bias_grad);

 template <typename T, typename Context>
-void BatchNormDoubleGradKernel(const Context& dev_ctx,
+void BatchNormDoubleGradKernel(
+    const Context& dev_ctx,
     const DenseTensor& x,
     const DenseTensor& scale,
     const paddle::optional<DenseTensor>& mean,
@@ -72,17 +73,17 @@ void BatchNormDoubleGradKernel(const Context& dev_ctx,
     const DenseTensor& saved_mean,
     const DenseTensor& saved_variance,
     const DenseTensor& y_grad,
-    const DenseTensor& x_grad_grad,
-    const DenseTensor& scale_grad_grad,
-    const DenseTensor& bias_grad_grad,
+    const paddle::optional<DenseTensor>& x_grad_grad,
+    const paddle::optional<DenseTensor>& scale_grad_grad,
+    const paddle::optional<DenseTensor>& bias_grad_grad,
     float momentum,
     float epsilon,
     const std::string& data_layout,
     bool is_test,
     bool use_global_stats,
     bool trainable_statistics,
-    bool fuse_with_relu,
     DenseTensor* x_grad,
     DenseTensor* scale_grad,
     DenseTensor* y_grad_grad);
 }  // namespace phi
@@ -334,7 +334,8 @@ void BatchNormGradKernel(const Context& dev_ctx,
 }

 template <typename T, typename Context>
-void BatchNormDoubleGradKernel(const Context& ctx,
+void BatchNormDoubleGradKernel(
+    const Context& ctx,
     const DenseTensor& x,
     const DenseTensor& scale,
     const paddle::optional<DenseTensor>& mean,
@@ -342,9 +343,9 @@ void BatchNormDoubleGradKernel(const Context& ctx,
     const DenseTensor& saved_mean,
     const DenseTensor& saved_variance,
     const DenseTensor& y_grad,
-    const DenseTensor& x_grad_grad,
-    const DenseTensor& scale_grad_grad,
-    const DenseTensor& bias_grad_grad,
+    const paddle::optional<DenseTensor>& x_grad_grad,
+    const paddle::optional<DenseTensor>& scale_grad_grad,
+    const paddle::optional<DenseTensor>& bias_grad_grad,
     float momentum,
     float epsilon,
     const std::string& data_layout_str,
@@ -369,9 +370,9 @@ void BatchNormDoubleGradKernel(const Context& ctx,
   const auto data_layout = phi::StringToDataLayout(data_layout_str);

-  const auto* ddX = &x_grad_grad;
-  const auto* ddScale = &scale_grad_grad;
-  const auto* ddBias = &bias_grad_grad;
+  const auto* ddX = x_grad_grad.get_ptr();
+  const auto* ddScale = scale_grad_grad.get_ptr();
+  const auto* ddBias = bias_grad_grad.get_ptr();

   auto* dX = x_grad;
   auto* dScale = scale_grad;
......
@@ -108,6 +108,9 @@ PD_REGISTER_KERNEL(full_like,
                    int,
                    int64_t,
                    bool,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {
   kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
 }
} }
@@ -47,8 +47,8 @@ void MultiplyTripleGradKernel(const Context& dev_ctx,
                               const DenseTensor& dout,
                               const paddle::optional<DenseTensor>& ddx,
                               const paddle::optional<DenseTensor>& ddy,
-                              const DenseTensor& d_dx,
-                              const DenseTensor& d_dy,
+                              const paddle::optional<DenseTensor>& d_dx,
+                              const paddle::optional<DenseTensor>& d_dy,
                               const paddle::optional<DenseTensor>& d_ddout,
                               int axis,
                               DenseTensor* d_x,
......
@@ -125,15 +125,24 @@ struct SinDoubleGradFunctor : public BaseActivationFunctor<T> {
     // calculate d2x first, so d2d1y can inplace d2d1x
     auto d2x = EigenVector<T>::Flatten(
         GET_DATA_SAFELY(dX, "Output", "d2x", "SinDoubleGrad"));
+    if (dX) {
+      if (dOut) {
         auto d1y = EigenVector<T>::Flatten(
             GET_DATA_SAFELY(dOut, "Output", "d1y", "SinDoubleGrad"));
         d2x.device(*d) = -d2d1x * x.unaryExpr(Sine<T>()) * d1y;
+      } else {
+        d2x.device(*d) = -d2d1x * x.unaryExpr(Sine<T>()) * static_cast<T>(0);
+      }
+    }

     // calculate d2d1y
+    if (ddOut) {
       auto d2d1y = EigenVector<T>::Flatten(
           GET_DATA_SAFELY(ddOut, "Output", "d2d1y", "SinDoubleGrad"));
       d2d1y.device(*d) = d2d1x * x.unaryExpr(Cosine<T>());
     }
+  }

   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
 };
@@ -167,28 +176,71 @@ struct SinTripleGradFunctor : public BaseActivationFunctor<T> {
     auto* d = dev.eigen_device();
     auto x = EigenVector<T>::Flatten(
         GET_DATA_SAFELY(X, "Input", "x", "SinTripleGrad"));
+    auto d3d2x = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(d_dx_New, "Input", "d3d2x", "SinTripleGrad"));
+    if (d_x_New) {
+      auto d3x = EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(d_x_New, "Output", "d3x", "SinTripleGrad"));
+      if (dOut && ddX && d_DDOut) {
         auto d2d1x = EigenVector<T>::Flatten(
             GET_DATA_SAFELY(ddX, "Input", "d2d1x", "SinTripleGrad"));
         auto d1y = EigenVector<T>::Flatten(
             GET_DATA_SAFELY(dOut, "Input", "d1y", "SinTripleGrad"));
         auto d3d2d1y = EigenVector<T>::Flatten(
             GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "SinTripleGrad"));
-    auto d3d2x = EigenVector<T>::Flatten(
-        GET_DATA_SAFELY(d_dx_New, "Input", "d3d2x", "SinTripleGrad"));
-    auto d3x = EigenVector<T>::Flatten(
-        GET_DATA_SAFELY(d_x_New, "Output", "d3x", "SinTripleGrad"));
         d3x.device(*d) = -x.unaryExpr(Cosine<T>()) * d1y * d2d1x * d3d2x -
                          x.unaryExpr(Sine<T>()) * d2d1x * d3d2d1y;
+      } else if (!dOut && ddX && d_DDOut) {
+        auto d2d1x = EigenVector<T>::Flatten(
+            GET_DATA_SAFELY(ddX, "Input", "d2d1x", "SinTripleGrad"));
+        auto d3d2d1y = EigenVector<T>::Flatten(
+            GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "SinTripleGrad"));
+        d3x.device(*d) = -x.unaryExpr(Sine<T>()) * d2d1x * d3d2d1y;
+      } else if (dOut && ddX && !d_DDOut) {
+        auto d2d1x = EigenVector<T>::Flatten(
+            GET_DATA_SAFELY(ddX, "Input", "d2d1x", "SinTripleGrad"));
+        auto d1y = EigenVector<T>::Flatten(
+            GET_DATA_SAFELY(dOut, "Input", "d1y", "SinTripleGrad"));
+        d3x.device(*d) = -x.unaryExpr(Cosine<T>()) * d1y * d2d1x * d3d2x;
+      } else {
+        d3x.device(*d) = x * static_cast<T>(0);
+      }
+    }

+    if (d_d_Out) {
       auto d3d1y = EigenVector<T>::Flatten(
           GET_DATA_SAFELY(d_d_Out, "Output", "d3d1y", "SinTripleGrad"));
+      if (ddX) {
+        auto d2d1x = EigenVector<T>::Flatten(
+            GET_DATA_SAFELY(ddX, "Input", "d2d1x", "SinTripleGrad"));
         d3d1y.device(*d) = -x.unaryExpr(Sine<T>()) * d2d1x * d3d2x;
+      } else {
+        d3d1y.device(*d) = static_cast<T>(0) * x;
+      }
+    }

+    if (d_DDx) {
       auto d3d2d1x = EigenVector<T>::Flatten(
           GET_DATA_SAFELY(d_DDx, "Output", "d3d2d1x", "SinTripleGrad"));
+      if (dOut && d_DDOut) {
+        auto d1y = EigenVector<T>::Flatten(
+            GET_DATA_SAFELY(dOut, "Input", "d1y", "SinTripleGrad"));
+        auto d3d2d1y = EigenVector<T>::Flatten(
+            GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "SinTripleGrad"));
         d3d2d1x.device(*d) = -x.unaryExpr(Sine<T>()) * d1y * d3d2x +
                              x.unaryExpr(Cosine<T>()) * d3d2d1y;
+      } else if (dOut && !d_DDOut) {
+        auto d1y = EigenVector<T>::Flatten(
+            GET_DATA_SAFELY(dOut, "Input", "d1y", "SinTripleGrad"));
+        d3d2d1x.device(*d) = -x.unaryExpr(Sine<T>()) * d1y * d3d2x;
+      } else if (!dOut && d_DDOut) {
+        auto d3d2d1y = EigenVector<T>::Flatten(
+            GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "SinTripleGrad"));
+        d3d2d1x.device(*d) = x.unaryExpr(Cosine<T>()) * d3d2d1y;
+      } else {
+        d3d2d1x.device(*d) = x * static_cast<T>(0);
+      }
+    }
   }

   static constexpr ActBwdOpFwdDeps FwdDeps() {
     return ActBwdOpFwdDeps::kDepOut;
@@ -270,15 +322,23 @@ struct CosDoubleGradFunctor : public BaseActivationFunctor<T> {
     // calculate d2x first, so d2d1y can inplace d2d1x
     auto d2x = EigenVector<T>::Flatten(
         GET_DATA_SAFELY(dX, "Output", "d2x", "CosDoubleGrad"));
+    if (ddOut) {
+      if (dOut) {
         auto d1y = EigenVector<T>::Flatten(
             GET_DATA_SAFELY(dOut, "Output", "d1y", "CosDoubleGrad"));
         d2x.device(*d) = -d2d1x * x.unaryExpr(Cosine<T>()) * d1y;
+      } else {
+        d2x.device(*d) = x * static_cast<T>(0);
+      }
+    }

+    if (dX) {
       // calculate d2d1y
       auto d2d1y = EigenVector<T>::Flatten(
           GET_DATA_SAFELY(ddOut, "Output", "d2d1y", "CosDoubleGrad"));
       d2d1y.device(*d) = -d2d1x * x.unaryExpr(Sine<T>());
     }
+  }

   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
 };
@@ -297,28 +357,72 @@ struct CosTripleGradFunctor : public BaseActivationFunctor<T> {
     auto* d = dev.eigen_device();
     auto x = EigenVector<T>::Flatten(
         GET_DATA_SAFELY(X, "Input", "x", "CosTripleGrad"));
+    auto d3d2x = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(d_dx_New, "Input", "d3d2x", "CosTripleGrad"));
+    if (d_x_New) {
+      auto d3x = EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(d_x_New, "Output", "d3x", "CosTripleGrad"));
+      if (dOut && ddX && d_DDOut) {
         auto d2d1x = EigenVector<T>::Flatten(
             GET_DATA_SAFELY(ddX, "Input", "d2d1x", "CosTripleGrad"));
         auto d1y = EigenVector<T>::Flatten(
             GET_DATA_SAFELY(dOut, "Input", "d1y", "CosTripleGrad"));
         auto d3d2d1y = EigenVector<T>::Flatten(
             GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "CosTripleGrad"));
-    auto d3d2x = EigenVector<T>::Flatten(
-        GET_DATA_SAFELY(d_dx_New, "Input", "d3d2x", "CosTripleGrad"));
-    auto d3x = EigenVector<T>::Flatten(
-        GET_DATA_SAFELY(d_x_New, "Output", "d3x", "CosTripleGrad"));
         d3x.device(*d) = x.unaryExpr(Sine<T>()) * d1y * d2d1x * d3d2x -
                          x.unaryExpr(Cosine<T>()) * d2d1x * d3d2d1y;
+      } else if (dOut && ddX && !d_DDOut) {
+        auto d2d1x = EigenVector<T>::Flatten(
+            GET_DATA_SAFELY(ddX, "Input", "d2d1x", "CosTripleGrad"));
+        auto d1y = EigenVector<T>::Flatten(
+            GET_DATA_SAFELY(dOut, "Input", "d1y", "CosTripleGrad"));
+        d3x.device(*d) = x.unaryExpr(Sine<T>()) * d1y * d2d1x * d3d2x;
+      } else if (!dOut && ddX && d_DDOut) {
+        auto d2d1x = EigenVector<T>::Flatten(
+            GET_DATA_SAFELY(ddX, "Input", "d2d1x", "CosTripleGrad"));
+        auto d3d2d1y = EigenVector<T>::Flatten(
+            GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "CosTripleGrad"));
+        d3x.device(*d) = -x.unaryExpr(Cosine<T>()) * d2d1x * d3d2d1y;
+      } else {
+        d3x.device(*d) = static_cast<T>(0) * x;
+      }
+    }

+    if (d_d_Out) {
       auto d3d1y = EigenVector<T>::Flatten(
           GET_DATA_SAFELY(d_d_Out, "Output", "d3d1y", "CosTripleGrad"));
+      if (ddX) {
+        auto d2d1x = EigenVector<T>::Flatten(
+            GET_DATA_SAFELY(ddX, "Input", "d2d1x", "CosTripleGrad"));
         d3d1y.device(*d) = -x.unaryExpr(Cosine<T>()) * d2d1x * d3d2x;
+      } else {
+        d3d1y.device(*d) = static_cast<T>(0) * x;
+      }
+    }

+    if (d_DDx) {
       auto d3d2d1x = EigenVector<T>::Flatten(
           GET_DATA_SAFELY(d_DDx, "Output", "d3d2d1x", "CosTripleGrad"));
+      if (dOut && d_DDOut) {
+        auto d1y = EigenVector<T>::Flatten(
+            GET_DATA_SAFELY(dOut, "Input", "d1y", "CosTripleGrad"));
+        auto d3d2d1y = EigenVector<T>::Flatten(
+            GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "CosTripleGrad"));
         d3d2d1x.device(*d) = -x.unaryExpr(Cosine<T>()) * d1y * d3d2x -
                              x.unaryExpr(Sine<T>()) * d3d2d1y;
+      } else if (!dOut && d_DDOut) {
+        auto d3d2d1y = EigenVector<T>::Flatten(
+            GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "CosTripleGrad"));
+        d3d2d1x.device(*d) = -x.unaryExpr(Sine<T>()) * d3d2d1y;
+      } else if (dOut && !d_DDOut) {
+        auto d1y = EigenVector<T>::Flatten(
+            GET_DATA_SAFELY(dOut, "Input", "d1y", "CosTripleGrad"));
+        d3d2d1x.device(*d) = -x.unaryExpr(Cosine<T>()) * d1y * d3d2x;
+      } else {
+        d3d2d1x.device(*d) = static_cast<T>(0) * x;
+      }
+    }
   }

   static constexpr ActBwdOpFwdDeps FwdDeps() {
     return ActBwdOpFwdDeps::kDepOut;
@@ -1106,27 +1210,70 @@ struct TanhTripleGradFunctor : public BaseActivationFunctor<T> {
         GET_DATA_SAFELY(Out, "Input", "Out", "TanhTripleGrad"));
     auto dout = EigenVector<T>::Flatten(
         GET_DATA_SAFELY(dOut, "Input", "DOut", "TanhTripleGrad"));
-    auto d_ddOut = EigenVector<T>::Flatten(
-        GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "TanhTripleGrad"));
-    auto d_dOutNew = EigenVector<T>::Flatten(
-        GET_DATA_SAFELY(d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad"));

     if (d_Out_New) {
       auto d_OutNew = EigenVector<T>::Flatten(
           GET_DATA_SAFELY(d_Out_New, "Output", "D_OutNew", "TanhTripleGrad"));
+      if (d_DDOut && d_dOut_New) {
+        auto d_ddOut = EigenVector<T>::Flatten(
+            GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "TanhTripleGrad"));
+        auto d_dOutNew = EigenVector<T>::Flatten(GET_DATA_SAFELY(
+            d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad"));
         d_OutNew.device(*d) = (static_cast<T>(-2) * out * ddx * d_ddOut) -
                               (static_cast<T>(2) * dout * ddx * d_dOutNew);
+      } else if (d_DDOut && !d_dOut_New) {
+        auto d_ddOut = EigenVector<T>::Flatten(
+            GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "TanhTripleGrad"));
+        d_OutNew.device(*d) = (static_cast<T>(-2) * out * ddx * d_ddOut);
+      } else if (!d_DDOut && d_dOut_New) {
+        auto d_dOutNew = EigenVector<T>::Flatten(GET_DATA_SAFELY(
+            d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad"));
+        d_OutNew.device(*d) = -(static_cast<T>(2) * dout * ddx * d_dOutNew);
+      } else {
+        d_OutNew.device(*d) = static_cast<T>(0) * out;
+      }
     }

     if (d_d_Out) {
       auto d_dOut = EigenVector<T>::Flatten(
           GET_DATA_SAFELY(d_d_Out, "Output", "D_DOut", "TanhTripleGrad"));
+      if (d_dOut_New) {
+        auto d_dOutNew = EigenVector<T>::Flatten(GET_DATA_SAFELY(
+            d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad"));
         d_dOut.device(*d) = static_cast<T>(-2) * out * ddx * d_dOutNew;
+      } else {
+        d_dOut.device(*d) = static_cast<T>(0) * out;
+      }
     }

     if (d_DDx) {
       auto d_ddx = EigenVector<T>::Flatten(
           GET_DATA_SAFELY(d_DDx, "Output", "D_DDx", "TanhTripleGrad"));
+      if (d_DDOut && d_dOut_New) {
+        auto d_ddOut = EigenVector<T>::Flatten(
+            GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "TanhTripleGrad"));
+        auto d_dOutNew = EigenVector<T>::Flatten(GET_DATA_SAFELY(
+            d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad"));
         d_ddx.device(*d) = (static_cast<T>(1) - (out * out)) * d_ddOut -
                            static_cast<T>(2) * out * dout * d_dOutNew;
+      } else if (d_DDOut && !d_dOut_New) {
+        auto d_ddOut = EigenVector<T>::Flatten(
+            GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "TanhTripleGrad"));
+        d_ddx.device(*d) = (static_cast<T>(1) - (out * out)) * d_ddOut;
+      } else if (!d_DDOut && d_dOut_New) {
+        auto d_dOutNew = EigenVector<T>::Flatten(GET_DATA_SAFELY(
+            d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad"));
+        d_ddx.device(*d) = -static_cast<T>(2) * out * dout * d_dOutNew;
+      } else {
+        d_ddx.device(*d) = static_cast<T>(0) * ddx;
+      }
     }
   }

   static constexpr ActBwdOpFwdDeps FwdDeps() {
......
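The zero-fill branches above all implement one rule: an absent grad input is treated as zero and every term containing it is dropped. For SinDoubleGradFunctor, writing dOut as $\bar y$, ddX as $\tilde x$, ddOut as $\tilde{\bar y}$, and dX as $\hat x$ (notation ours, not the diff's):

\begin{aligned}
y &= \sin x, & \bar x &= \bar y \cos x,\\
\tilde{\bar y} &= \tilde x \cos x, & \hat x &= -\tilde x\,\bar y \sin x .
\end{aligned}

With $\bar y$ missing, $\hat x = 0$, which is exactly the `static_cast<T>(0)` fallback; the triple-grad functors drop terms from their three-way products in the same way.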
@@ -1295,7 +1295,8 @@ void BatchNormGradKernel(const Context &dev_ctx,
 }

 template <typename T, typename Context>
-void BatchNormDoubleGradKernel(const Context &ctx,
+void BatchNormDoubleGradKernel(
+    const Context &ctx,
     const DenseTensor &x,
     const DenseTensor &scale,
     const paddle::optional<DenseTensor> &mean,
@@ -1303,9 +1304,9 @@ void BatchNormDoubleGradKernel(const Context &ctx,
     const DenseTensor &saved_mean,
     const DenseTensor &saved_variance,
     const DenseTensor &y_grad,
-    const DenseTensor &x_grad_grad,
-    const DenseTensor &scale_grad_grad,
-    const DenseTensor &bias_grad_grad,
+    const paddle::optional<DenseTensor> &x_grad_grad,
+    const paddle::optional<DenseTensor> &scale_grad_grad,
+    const paddle::optional<DenseTensor> &bias_grad_grad,
     float momentum,
     float epsilon,
     const std::string &data_layout_str,
@@ -1330,7 +1331,8 @@ void BatchNormDoubleGradKernel(const Context &ctx,
     running_mean = mean.get_ptr();
     running_variance = variance.get_ptr();
   }
-  paddle::operators::NormDoubleGradFunctor<Context, T>(ctx,
+  paddle::operators::NormDoubleGradFunctor<Context, T>(
+      ctx,
       data_layout,
       &x,
       &scale,
@@ -1341,9 +1343,9 @@ void BatchNormDoubleGradKernel(const Context &ctx,
       running_variance,
       epsilon,
       use_global_stats,
-      &x_grad_grad,
-      &scale_grad_grad,
-      &bias_grad_grad,
+      x_grad_grad.get_ptr(),
+      scale_grad_grad.get_ptr(),
+      bias_grad_grad.get_ptr(),
       x_grad,
       scale_grad,
       y_grad_grad);
......
@@ -138,6 +138,8 @@ PD_REGISTER_KERNEL(full_like,
                    int64_t,
                    bool,
                    phi::dtype::bfloat16,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {
   kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
 }
} }
@@ -177,8 +177,8 @@ void TanhTripleGradKernel(const Context& dev_ctx,
                           const DenseTensor& out,
                           const DenseTensor& dout,
                           const DenseTensor& ddx,
-                          const DenseTensor& d_dout_new,
-                          const DenseTensor& d_ddout,
+                          const paddle::optional<DenseTensor>& d_dout_new,
+                          const paddle::optional<DenseTensor>& d_ddout,
                           DenseTensor* d_out_new,
                           DenseTensor* d_dout,
                           DenseTensor* d_ddx) {
@@ -199,8 +199,8 @@ void TanhTripleGradKernel(const Context& dev_ctx,
           &out,
           &ddx,
           &dout,
-          &d_ddout,
-          &d_dout_new,  // input
+          d_ddout.get_ptr(),
+          d_dout_new.get_ptr(),  // input
           d_dout,
           d_out_new,
           d_ddx);  // output
@@ -597,49 +597,45 @@ void SquareDoubleGradKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void SinDoubleGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
-                         const DenseTensor& dout,
+                         const paddle::optional<DenseTensor>& dout,
                          const DenseTensor& ddx,
                          DenseTensor* dx,
                          DenseTensor* ddout) {
   if (dx) {
-    dx->Resize(x.dims());
     dev_ctx.template Alloc<T>(dx);
   }
   if (ddout) {
     dev_ctx.template Alloc<T>(ddout);
   }
   phi::funcs::SinDoubleGradFunctor<T> functor;
-  functor(dev_ctx, &x, &dout, &ddx, dx, ddout);
+  functor(dev_ctx, &x, dout.get_ptr(), &ddx, dx, ddout);
 }

 template <typename T, typename Context>
 void SinTripleGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
-                         const DenseTensor& dout,
-                         const DenseTensor& ddx,
+                         const paddle::optional<DenseTensor>& dout,
+                         const paddle::optional<DenseTensor>& ddx,
                          const DenseTensor& d_dx_new,
-                         const DenseTensor& d_ddout,
+                         const paddle::optional<DenseTensor>& d_ddout,
                          DenseTensor* d_x_new,
                          DenseTensor* d_dout,
                          DenseTensor* d_ddx) {
   if (d_dout) {
-    d_dout->Resize(x.dims());
     dev_ctx.template Alloc<T>(d_dout);
   }
   if (d_x_new) {
-    d_dout->Resize(x.dims());
     dev_ctx.template Alloc<T>(d_x_new);
   }
   if (d_ddx) {
-    d_dout->Resize(ddx.dims());
     dev_ctx.template Alloc<T>(d_ddx);
   }
   funcs::SinTripleGradFunctor<T> functor;
   functor(dev_ctx,
           &x,
-          &ddx,
-          &dout,
-          &d_ddout,
+          ddx.get_ptr(),
+          dout.get_ptr(),
+          d_ddout.get_ptr(),
           &d_dx_new,  // input
           d_dout,
           d_x_new,
@@ -649,49 +645,45 @@ void SinTripleGradKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void CosDoubleGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
-                         const DenseTensor& dout,
+                         const paddle::optional<DenseTensor>& dout,
                          const DenseTensor& ddx,
                          DenseTensor* dx,
                          DenseTensor* ddout) {
   if (dx) {
-    dx->Resize(x.dims());
     dev_ctx.template Alloc<T>(dx);
   }
   if (ddout) {
     dev_ctx.template Alloc<T>(ddout);
   }
   phi::funcs::CosDoubleGradFunctor<T> functor;
-  functor(dev_ctx, &x, &dout, &ddx, dx, ddout);
+  functor(dev_ctx, &x, dout.get_ptr(), &ddx, dx, ddout);
 }

 template <typename T, typename Context>
 void CosTripleGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
-                         const DenseTensor& dout,
-                         const DenseTensor& ddx,
+                         const paddle::optional<DenseTensor>& dout,
+                         const paddle::optional<DenseTensor>& ddx,
                          const DenseTensor& d_dx_new,
-                         const DenseTensor& d_ddout,
+                         const paddle::optional<DenseTensor>& d_ddout,
                          DenseTensor* d_x_new,
                          DenseTensor* d_dout,
                          DenseTensor* d_ddx) {
   if (d_dout) {
-    d_dout->Resize(x.dims());
     dev_ctx.template Alloc<T>(d_dout);
   }
   if (d_x_new) {
-    d_dout->Resize(x.dims());
     dev_ctx.template Alloc<T>(d_x_new);
   }
   if (d_ddx) {
-    d_dout->Resize(ddx.dims());
     dev_ctx.template Alloc<T>(d_ddx);
   }
   funcs::CosTripleGradFunctor<T> functor;
   functor(dev_ctx,
           &x,
-          &ddx,
-          &dout,
-          &d_ddout,
+          ddx.get_ptr(),
+          dout.get_ptr(),
+          d_ddout.get_ptr(),
           &d_dx_new,  // input
           d_dout,
           d_x_new,
......
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/broadcast_function.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
@@ -472,6 +473,7 @@ void MultiplyDoubleGradKernel(const Context& dev_ctx,
                                         funcs::MultiplyFunctor<T>,
                                         funcs::InverseMultiplyFunctor<T>>(
           dev_ctx, y, ddx_safe, ddout, axis);
+
       funcs::DefaultElementwiseOperator<Context,
                                         T,
                                         funcs::MultiplyFunctor<T>,
@@ -483,12 +485,14 @@ void MultiplyDoubleGradKernel(const Context& dev_ctx,
       ddout_t.device(place) = ddout_t + ddout_tmp_t;
     } else {
       // use dx to save memory, other than alloc tmp tensor
+      if (dx) {
         DenseTensor* ddout_tmp = dx;
         funcs::DefaultElementwiseOperator<Context,
                                           T,
                                           funcs::MultiplyFunctor<T>,
                                           funcs::InverseMultiplyFunctor<T>>(
             dev_ctx, x, ddy_safe, ddout_tmp, axis);
+
         // NOTE: in the following ElemwiseGradCompute, for the
         // first output tensor is nullptr, the branch to calculate first
         // output tensor will not be activated, DivGradDx function will not
@@ -505,6 +509,7 @@ void MultiplyDoubleGradKernel(const Context& dev_ctx,
             dy,
             MulGradDX<T>(),
             MulGradDY<T>());
+
         funcs::DefaultElementwiseOperator<Context,
                                           T,
                                           funcs::MultiplyFunctor<T>,
@@ -514,11 +519,36 @@ void MultiplyDoubleGradKernel(const Context& dev_ctx,
         auto ddout_t = phi::EigenVector<T>::Flatten(*ddout);
         auto ddout_tmp_t = phi::EigenVector<T>::Flatten(*ddout_tmp);
         ddout_t.device(place) = ddout_t + ddout_tmp_t;
+
         funcs::DefaultElementwiseOperator<Context,
                                           T,
                                           funcs::MultiplyFunctor<T>,
                                           funcs::InverseMultiplyFunctor<T>>(
             dev_ctx, dout, ddy_safe, dx, axis);
+      } else {
+        DenseTensor tmp_a(ddout->dtype());
+        tmp_a.Resize(ddout->dims());
+
+        dev_ctx.template Alloc<T>(&tmp_a);
+
+        funcs::DefaultElementwiseOperator<Context,
+                                          T,
+                                          funcs::MultiplyFunctor<T>,
+                                          funcs::InverseMultiplyFunctor<T>>(
+            dev_ctx, x, ddy_safe, &tmp_a, axis);
+
+        auto ddout_t1 = phi::EigenVector<T>::Flatten(tmp_a);
+
+        funcs::DefaultElementwiseOperator<Context,
+                                          T,
+                                          funcs::MultiplyFunctor<T>,
+                                          funcs::InverseMultiplyFunctor<T>>(
+            dev_ctx, ddx_safe, y, ddout, axis);
+
+        auto ddout_t2 = phi::EigenVector<T>::Flatten(*ddout);
+        ddout_t2.device(place) = ddout_t2 + ddout_t1;
+      }
     }
   } else {
     if (dx && dy) {
@@ -544,8 +574,8 @@ void MultiplyTripleGradKernel(const Context& dev_ctx,
                               const DenseTensor& dout,
                               const paddle::optional<DenseTensor>& ddx,
                               const paddle::optional<DenseTensor>& ddy,
-                              const DenseTensor& d_dx,
-                              const DenseTensor& d_dy,
+                              const paddle::optional<DenseTensor>& d_dx,
+                              const paddle::optional<DenseTensor>& d_dy,
                               const paddle::optional<DenseTensor>& d_ddout,
                               int axis,
                               DenseTensor* d_x,
@@ -599,6 +629,13 @@ void MultiplyTripleGradKernel(const Context& dev_ctx,
                                         funcs::InverseMultiplyFunctor<T>>(
           dev_ctx, ddx_safe, *(d_ddout.get_ptr()), d_y, axis);
     }
+  } else {
+    if (d_x) {
+      FullLikeKernel<T, Context>(dev_ctx, x, Scalar(0.0), x.dtype(), d_x);
+    }
+    if (d_y) {
+      FullLikeKernel<T, Context>(dev_ctx, y, Scalar(0.0), y.dtype(), d_y);
+    }
   }

   if (d_dout) {
@@ -607,61 +644,135 @@ void MultiplyTripleGradKernel(const Context& dev_ctx,
     DenseTensor d_dout_tmp;
     d_dout_tmp.Resize(dout.dims());
     dev_ctx.template Alloc<T>(&d_dout_tmp);
+    if (d_dy && d_dx) {
       funcs::DefaultElementwiseOperator<Context,
                                         T,
                                         funcs::MultiplyFunctor<T>,
                                         funcs::InverseMultiplyFunctor<T>>(
-          dev_ctx, d_dy, ddx_safe, d_dout, axis);
+          dev_ctx, d_dy.get(), ddx_safe, d_dout, axis);
+
       funcs::DefaultElementwiseOperator<Context,
                                         T,
                                         funcs::MultiplyFunctor<T>,
                                         funcs::InverseMultiplyFunctor<T>>(
-          dev_ctx, ddy_safe, d_dx, &d_dout_tmp, axis);
+          dev_ctx, ddy_safe, d_dx.get(), &d_dout_tmp, axis);
+
       auto d_dout_t = phi::EigenVector<T>::Flatten(*d_dout);
       auto d_dout_tmp_t = phi::EigenVector<T>::Flatten(d_dout_tmp);
       d_dout_t.device(place) = d_dout_t + d_dout_tmp_t;
+    } else if (d_dy && !d_dx) {
+      funcs::DefaultElementwiseOperator<Context,
+                                        T,
+                                        funcs::MultiplyFunctor<T>,
+                                        funcs::InverseMultiplyFunctor<T>>(
+          dev_ctx, d_dy.get(), ddx_safe, d_dout, axis);
+      auto d_dout_t = phi::EigenVector<T>::Flatten(*d_dout);
+      d_dout_t.device(place) = d_dout_t;
+    } else if (!d_dy && d_dx) {
+      funcs::DefaultElementwiseOperator<Context,
+                                        T,
+                                        funcs::MultiplyFunctor<T>,
+                                        funcs::InverseMultiplyFunctor<T>>(
+          dev_ctx, ddy_safe, d_dx.get(), d_dout, axis);
+      auto d_dout_t = phi::EigenVector<T>::Flatten(*d_dout);
+      d_dout_t.device(place) = d_dout_t;
+    } else {
+      FullLikeKernel<T, Context>(
+          dev_ctx, dout, Scalar(0.0), dout.dtype(), d_dout);
+    }
   }

-  if (d_ddx) {
+  if (d_ddx && ddx) {
     // get d_ddx
     // d_ddx = dout * d_dy + y * d_ddout
     DenseTensor d_ddx_tmp;
     d_ddx_tmp.Resize(ddx->dims());
     dev_ctx.template Alloc<T>(&d_ddx_tmp);
+    if (d_dy && d_ddout) {
       funcs::DefaultElementwiseOperator<Context,
                                         T,
                                         funcs::MultiplyFunctor<T>,
                                         funcs::InverseMultiplyFunctor<T>>(
-          dev_ctx, dout, d_dy, d_ddx, axis);
+          dev_ctx, dout, d_dy.get(), d_ddx, axis);
+
       funcs::DefaultElementwiseOperator<Context,
                                         T,
                                         funcs::MultiplyFunctor<T>,
                                         funcs::InverseMultiplyFunctor<T>>(
           dev_ctx, y, *(d_ddout.get_ptr()), &d_ddx_tmp, axis);
+
       auto d_ddx_t = phi::EigenVector<T>::Flatten(*d_ddx);
       auto d_ddx_tmp_t = phi::EigenVector<T>::Flatten(d_ddx_tmp);
       d_ddx_t.device(place) = d_ddx_t + d_ddx_tmp_t;
+    } else if (d_dy && !d_ddout) {
+      funcs::DefaultElementwiseOperator<Context,
+                                        T,
+                                        funcs::MultiplyFunctor<T>,
+                                        funcs::InverseMultiplyFunctor<T>>(
+          dev_ctx, dout, d_dy.get(), d_ddx, axis);
+      auto d_ddx_t = phi::EigenVector<T>::Flatten(*d_ddx);
+      d_ddx_t.device(place) = d_ddx_t;
+    } else if (!d_dy && d_ddout) {
+      funcs::DefaultElementwiseOperator<Context,
+                                        T,
+                                        funcs::MultiplyFunctor<T>,
+                                        funcs::InverseMultiplyFunctor<T>>(
+          dev_ctx, y, *(d_ddout.get_ptr()), d_ddx, axis);
+      auto d_ddx_t = phi::EigenVector<T>::Flatten(*d_ddx);
+      d_ddx_t.device(place) = d_ddx_t;
+    } else {
+      FullLikeKernel<T, Context>(dev_ctx, x, Scalar(0.0), x.dtype(), d_ddx);
+    }
   }

-  if (d_ddy) {
+  if (d_ddy && ddy) {
     // get d_ddy
     // d_ddy = dout * d_dx + x * d_ddout
     DenseTensor d_ddy_tmp;
     d_ddy_tmp.Resize(ddy->dims());
     dev_ctx.template Alloc<T>(&d_ddy_tmp);
+    if (d_dx && d_ddout) {
       funcs::DefaultElementwiseOperator<Context,
                                         T,
                                         funcs::MultiplyFunctor<T>,
                                         funcs::InverseMultiplyFunctor<T>>(
-          dev_ctx, dout, d_dx, d_ddy, axis);
+          dev_ctx, dout, d_dx.get(), d_ddy, axis);
+
       funcs::DefaultElementwiseOperator<Context,
                                         T,
                                         funcs::MultiplyFunctor<T>,
                                         funcs::InverseMultiplyFunctor<T>>(
           dev_ctx, x, *(d_ddout.get_ptr()), &d_ddy_tmp, axis);
+
       auto d_ddy_t = phi::EigenVector<T>::Flatten(*d_ddy);
       auto d_ddy_tmp_t = phi::EigenVector<T>::Flatten(d_ddy_tmp);
       d_ddy_t.device(place) = d_ddy_t + d_ddy_tmp_t;
+    } else if (d_dx && !d_ddout) {
+      funcs::DefaultElementwiseOperator<Context,
+                                        T,
+                                        funcs::MultiplyFunctor<T>,
+                                        funcs::InverseMultiplyFunctor<T>>(
+          dev_ctx, dout, d_dx.get(), d_ddy, axis);
+      auto d_ddy_t = phi::EigenVector<T>::Flatten(*d_ddy);
+      d_ddy_t.device(place) = d_ddy_t;
+    } else if (!d_dx && d_ddout) {
+      funcs::DefaultElementwiseOperator<Context,
+                                        T,
+                                        funcs::MultiplyFunctor<T>,
+                                        funcs::InverseMultiplyFunctor<T>>(
+          dev_ctx, x, *(d_ddout.get_ptr()), d_ddy, axis);
+      auto d_ddy_t = phi::EigenVector<T>::Flatten(*d_ddy);
+      d_ddy_t.device(place) = d_ddy_t;
+    } else {
+      FullLikeKernel<T, Context>(dev_ctx, y, Scalar(0.0), y.dtype(), d_ddy);
+    }
   }
 }
......
@@ -15,6 +15,7 @@
 #pragma once

 #include <limits>
+
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
......
@@ -28,4 +28,4 @@ void LogcumsumexpGradKernel(const Context& dev_ctx,
                             bool exclusive,
                             bool reverse,
                             DenseTensor* d_x);
-}
+}  // namespace phi
@@ -47,8 +47,8 @@ void MatmulTripleGradKernel(const Context& dev_ctx,
                             const DenseTensor& x,
                             const DenseTensor& y,
                             const DenseTensor& dout,
-                            const DenseTensor& ddx,
-                            const DenseTensor& ddy,
+                            const paddle::optional<DenseTensor>& ddx,
+                            const paddle::optional<DenseTensor>& ddy,
                             const paddle::optional<DenseTensor>& d_dx,
                             const paddle::optional<DenseTensor>& d_dy,
                             const paddle::optional<DenseTensor>& d_ddout,
......
@@ -140,6 +140,8 @@ PD_REGISTER_KERNEL(full_like,
                    float,
                    int,
                    int64_t,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {
   kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
 }