未验证 提交 d1e93be1 编写于 作者: J Jiabin Yang 提交者: GitHub

[Eager] Optimize Grad by prune useless branch (#47827)

* [Eager] Fix paddle.grad interface

* [Eager] Support minimum SubGraph for GeneralGrad

* Add needed_nodes to prune grad graph more thoroughly

* [Eager] Add grad_node_trans_mapping_ to record which grad_node has been transformed to AccumulationNode

* [Eager] Fix paddle.grad interface

* Polish code

* remove potential_stop_node

* Add endding_nodes to enhance genSugraph logic

* clear endding_nodes_

* polish code

* rename endding_nodes to endding_nades_

* Refactor grad interface

* Add register_hook case to fix coverage-ci

* Fix code format

* Refactor general_grad

* Add more code comments

* call clear directly to release GradSlotMeta

* fix a mistake

* fix matmul/ multiply kernel logic and optional input in yaml, fill zeros logic and so on.

* fix batch_norm_double_grad yaml optional config

* fix tanh_triple_grad yaml and kernels

* fix MultiplyTripleGradKernel optional logic

* fix merge mistake

* fix compile error

* remove legacy attr for bn

* polish code

* fix some kernel

* merge develop

* fix error

* remote log

* fix kernel with full like

* hide value log behind

* hide value log behind

* fix matmul_triple grad
Co-authored-by: NWeilong Wu <veyron_wu@163.com>
上级 0754e09d
......@@ -35,6 +35,7 @@ ops_to_fill_zero_for_empty_grads = set(
"multiply_triple_grad",
"conv2d_grad_grad",
"batch_norm_double_grad",
"tanh_grad",
"tanh_double_grad",
"tanh_triple_grad",
"sin_double_grad",
......
......@@ -230,7 +230,7 @@ FORWARD_FUNCTION_TEMPLATE = """
AFTER_LOG_PRINT_TEMPLATE = """
if(VLOG_IS_ON(4)){{
const char* INPUT_PRINT_TEMPLATE = \"{{ Input: [%s], Output: [%s] }} \";
const char* INPUT_PRINT_TEMPLATE = \"{{ Input: [%s], \\n Output: [%s] }} \";
{}
VLOG(4) << paddle::string::Sprintf(INPUT_PRINT_TEMPLATE, input_str, output_str);
}}
......
......@@ -173,9 +173,10 @@ std::vector<paddle::experimental::Tensor> RunBackward(
node_input_buffers_dict[grad_node] =
std::make_unique<GradTensorHolder>(grad_node->InputMeta());
}
bool copy_from_grad_t =
grad_tensors.size() > 0 && grad_tensors[i].initialized();
if (copy_from_grad_t) {
// copy grad tensor since we should totally run grad without affect forward
// value
if (grad_tensors.size() > 0 && grad_tensors[i].initialized()) {
PADDLE_ENFORCE(
grad_tensors.size() == tensors.size(),
paddle::platform::errors::Fatal(
......@@ -357,16 +358,6 @@ std::vector<paddle::experimental::Tensor> RunBackward(
"Node's in-degree cannot be negative.",
next_node->name()));
if (is_general_grad) {
if (node_in_degree_map[next_node] == 0 &&
GeneralGrad::Instance().IsNeededNodes(next_node)) {
if (dynamic_cast<egr::GradNodeAccumulation*>(next_node)) {
queue.push_front(std::move(next_node));
} else {
queue.push_back(std::move(next_node));
}
}
} else {
if (node_in_degree_map[next_node] == 0) {
if (dynamic_cast<egr::GradNodeAccumulation*>(next_node)) {
queue.push_front(std::move(next_node));
......@@ -377,7 +368,6 @@ std::vector<paddle::experimental::Tensor> RunBackward(
}
}
}
}
VLOG(7) << "Run Backward Final hook size: "
<< egr::Controller::Instance().FinalBackwardHooks().size();
......
......@@ -51,6 +51,10 @@ class GeneralGrad {
for (size_t i = 0; i < num_inputs; i++) {
AutogradMeta* auto_grad_meta =
EagerUtils::unsafe_autograd_meta(inputs[i]);
PADDLE_ENFORCE_NOT_NULL(
auto_grad_meta,
paddle::platform::errors::Fatal(
"We got %s:[%d] 's autograd meta is NULL.", msg, i));
auto* target_node = auto_grad_meta->GetMutableGradNode().get();
if (orig_to_copied_node_map_.count(target_node)) {
......@@ -82,10 +86,13 @@ class GeneralGrad {
// input_target_nodes
void PurifyPotentialStartUpNodes() {
VLOG(6) << "Running in PurifyPotentialStartUpNodes";
if (input_target_nodes_inputmeta_map_.empty()) return;
if (input_target_nodes_inputmeta_map_.empty()) {
VLOG(6) << "No input target nodes found, skip.";
return;
}
std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
for (auto startup_op : potential_startup_nodes_) {
auto iter = input_target_nodes_inputmeta_map_.find(startup_op);
for (auto startup_node : potential_startup_nodes_) {
auto iter = input_target_nodes_inputmeta_map_.find(startup_node);
if (iter != input_target_nodes_inputmeta_map_.end()) {
potential_startup_nodes_to_be_erased.emplace(iter->first);
}
......@@ -157,11 +164,11 @@ class GeneralGrad {
potential_startup_nodes_.erase(node);
}
}
}
} // TODO(jiabin): May we need some check here.
}
// Get Graph Info Betweent input target GradNode and outputs,
// record depending_nodes_, potential_startup_nodes_
// record depending_nodes_
void GetGraphInfoBetweenTargets(const std::deque<GradNodeBase*>& init_queue) {
VLOG(6) << "Runing In GetGraphInfoBetweenTargets";
......@@ -227,7 +234,7 @@ class GeneralGrad {
std::make_shared<paddle::experimental::Tensor>(target_result);
}
}
}
} // TODO(jiabin): Some check here.
}
void SetResultForEnddingNodes(
......@@ -319,21 +326,22 @@ class GeneralGrad {
void SetNodeToAccumulationNode(GradNodeBase* node) {
if (dynamic_cast<egr::GradNodeAccumulation*>(node)) return;
if (!(depending_nodes_)[node].empty()) {
// Find precedding_nodes of current node.
auto precedding_nodes = (depending_nodes_)[node];
for (auto pre_nodes : precedding_nodes) {
paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
pre_nodes_edges = pre_nodes->MutableOutputMeta();
for (size_t i = 0; i < pre_nodes_edges.size(); i++) {
for (size_t j = 0; j < pre_nodes_edges[i].size(); j++) {
auto edge_ = pre_nodes_edges[i][j].GetEdge();
const auto& edge_ = pre_nodes_edges[i][j].GetEdge();
if (edge_.GetGradNode() == node) {
auto autograd_meta = egr::AutogradMeta(edge_);
Edge& pre_node_edge = pre_nodes_edges[i][j].GetMutableEdge();
if (copied_node_to_endding_node_map_.count(node)) {
pre_node_edge.SetGradNode(
copied_node_to_endding_node_map_[node]);
} else {
auto autograd_meta = egr::AutogradMeta(edge_);
std::shared_ptr<GradNodeBase> shared_grad_node_accumulation =
std::make_shared<egr::GradNodeAccumulation>(&autograd_meta);
pre_node_edge.SetGradNode(shared_grad_node_accumulation);
......@@ -361,7 +369,7 @@ class GeneralGrad {
grad_node->SetGradientHookFuntions(
node->GetGradientHookFuntions());
}
}
} // or this node has no need to change
}
}
}
......@@ -381,12 +389,10 @@ class GeneralGrad {
}
visited.insert(node);
if (IsInputTargetNodes(node)) {
if (IsEnddingNodes(node)) {
if (IsInputTargetNodes(node) && IsEnddingNodes(node)) {
SetNodeToAccumulationNode(node);
continue;
}
}
paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
meta = node->MutableOutputMeta();
......@@ -411,7 +417,17 @@ class GeneralGrad {
continue;
}
// TODO(weilong): support prune logic deeper
if (meta.size() != 1 && IsNeededNodes(node) &&
!IsNeededNodes(next_node.get()) && !IsEnddingNodes(node)) {
VLOG(3) << "Get stop edge from grad_node: " << node->name() << " : "
<< node << " to:" << next_node->name() << ", "
<< next_node.get() << " with output rank info: " << i
<< ", " << j;
// No need to compute grad from needed Nodes to no need Nodes
meta[i][j].SetStopGradient(true);
edge.Clear();
continue;
}
// Update BFS queue
queue_.push_back(next_node.get());
......@@ -502,7 +518,8 @@ class GeneralGrad {
// Save node and update mapping
orig_to_copied_node_map_[orig_node.get()] = copied_node;
copied_grad_nodes_.push_back(copied_node);
VLOG(3) << "Copied Node: " << orig_node->name() << " ptr: " << orig_node
<< " to ptr: " << copied_node;
return copied_node.get();
}
......
......@@ -99,6 +99,11 @@ void GradTensorHolder::add(size_t slot_id,
size_t rank,
const paddle::experimental::Tensor& t,
bool create_graph) {
if (!t.initialized()) {
VLOG(3) << "No need to do accumulate for uninitialized t.";
return;
} // TODO(jiabin): Remove this when we fix all kernel.
PADDLE_ENFORCE(slot_id < buffer_.size(),
paddle::platform::errors::Fatal(
"Invalid slot_id for GradTensorHolder::add() "
......
......@@ -277,7 +277,58 @@ class EagerUtils {
} else {
tensor_info_str += "Unknown";
}
if (VLOG_IS_ON(6)) {
if (VLOG_IS_ON(11)) {
const char* TENSOR_PRINT_TEMPLATE =
"{Name: %s, Initialized: %d, Ptr: %d "
"TensorInfo: [ %s ], Value:[ %s ], ADInfo:[ %s ]}";
auto* ad_meta = nullable_autograd_meta(t);
if (ad_meta && (ad_meta->WeakGrad().lock().get())) {
std::string ad_info_str = "";
const char* AD_INFO_TEMPLATE =
"Grad: [ %s ], GradNode: [ %s ], StopGradient: [ %d ]";
ad_info_str += paddle::string::Sprintf(AD_INFO_TEMPLATE,
TensorStr(ad_meta->Grad()),
GradNodeStr(t),
ad_meta->StopGradient());
auto* data_ptr = dynamic_cast<phi::DenseTensor*>(t.impl().get());
if (t.is_initialized() && data_ptr) {
return paddle::string::Sprintf(TENSOR_PRINT_TEMPLATE,
tensor_name_str,
t.initialized(),
t.impl(),
tensor_info_str,
*data_ptr,
ad_info_str);
} else {
return paddle::string::Sprintf(TENSOR_PRINT_TEMPLATE,
tensor_name_str,
t.initialized(),
t.impl(),
tensor_info_str,
"None",
ad_info_str);
}
} else {
auto* data_ptr = dynamic_cast<phi::DenseTensor*>(t.impl().get());
if (t.is_initialized() && data_ptr) {
return paddle::string::Sprintf(TENSOR_PRINT_TEMPLATE,
tensor_name_str,
t.initialized(),
t.impl(),
tensor_info_str,
*data_ptr,
"None");
} else {
return paddle::string::Sprintf(TENSOR_PRINT_TEMPLATE,
tensor_name_str,
t.initialized(),
t.impl(),
tensor_info_str,
"None",
"None");
}
}
} else if (VLOG_IS_ON(6)) {
const char* TENSOR_PRINT_TEMPLATE =
"{Name: %s, Initialized: %d, Ptr: %d "
"TensorInfo: [ %s ], ADInfo:[ %s ]}";
......
......@@ -187,6 +187,7 @@
param : [x, x]
kernel :
func : cos_double_grad
optional: grad_out
backward : cos_triple_grad
inplace : (grad_x_grad -> grad_out_grad)
......@@ -211,6 +212,7 @@
param : [x, x, grad_x_grad_forward]
kernel :
func : cos_triple_grad
optional: grad_out_forward, grad_x_grad_forward, grad_out_grad_grad
inplace : (grad_x_grad_forward -> grad_out_forward_grad)
- backward_op : cosh_grad
......@@ -872,6 +874,7 @@
param : [x, x]
kernel :
func : sin_double_grad
optional: grad_out
backward : sin_triple_grad
inplace : (grad_x_grad -> grad_out_grad)
......@@ -896,6 +899,7 @@
param : [x, x, grad_x_grad_forward]
kernel :
func : sin_triple_grad
optional: grad_out_forward, grad_x_grad_forward, grad_out_grad_grad
inplace : (grad_x_grad_forward -> grad_out_forward_grad)
- backward_op : sinh_grad
......@@ -1054,6 +1058,7 @@
kernel :
func : tanh_triple_grad
inplace : (grad_x_grad_forward -> grad_out_forward_grad)
optional : grad_out_new_grad, grad_out_grad_grad
- backward_op : thresholded_relu_grad
forward : thresholded_relu (Tensor x, float threshold) -> Tensor(out)
......
......@@ -124,7 +124,7 @@
kernel :
func : batch_norm_grad_grad
data_type : x
optional : out_mean, out_variance
optional : out_mean, out_variance, grad_x_grad, grad_scale_grad, grad_bias_grad
inplace : (grad_out -> grad_out_grad)
- backward_op : batch_norm_grad
......@@ -856,7 +856,7 @@
param : [x, y, fwd_grad_out, fwd_grad_grad_x, fwd_grad_grad_y]
kernel :
func : matmul_triple_grad
optional : grad_x_grad, grad_y_grad, grad_grad_out_grad
optional : fwd_grad_grad_x, fwd_grad_grad_y, grad_x_grad, grad_y_grad, grad_grad_out_grad
- backward_op : max_grad
forward: max (Tensor x, IntArray axis={}, bool keepdim=false) -> Tensor(out)
......@@ -1024,10 +1024,10 @@
output : Tensor(x_grad), Tensor(y_grad), Tensor(fwd_grad_out_grad), Tensor(fwd_grad_grad_x_grad), Tensor(fwd_grad_grad_y_grad)
infer_meta :
func : GeneralQuinaryGradInferMeta
param : [x, y, fwd_grad_out, x, y]
param : [x, y, fwd_grad_out, fwd_grad_grad_x, fwd_grad_grad_y]
kernel :
func : multiply_triple_grad
optional : fwd_grad_grad_x, fwd_grad_grad_y, grad_grad_out_grad
optional : fwd_grad_grad_x, fwd_grad_grad_y, grad_x_grad, grad_y_grad, grad_grad_out_grad
- backward_op : nearest_interp_grad
forward : nearest_interp (Tensor x, Tensor out_size, Tensor[] size_tensor, Tensor scale_tensor, str data_layout, int out_d, int out_h, int out_w, float[] scale, str interp_method, bool align_corners, int align_mode) -> Tensor(output)
......
......@@ -83,7 +83,7 @@ void ReluDoubleGradKernel(const Context& dev_ctx,
template <typename T, typename Context>
void SinDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const paddle::optional<DenseTensor>& dout,
const DenseTensor& ddx,
DenseTensor* dx,
DenseTensor* ddout);
......@@ -91,7 +91,7 @@ void SinDoubleGradKernel(const Context& dev_ctx,
template <typename T, typename Context>
void CosDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const paddle::optional<DenseTensor>& dout,
const DenseTensor& ddx,
DenseTensor* dx,
DenseTensor* ddout);
......@@ -109,8 +109,8 @@ void TanhTripleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& dout,
const DenseTensor& ddx,
const DenseTensor& d_dout_new,
const DenseTensor& d_ddout,
const paddle::optional<DenseTensor>& d_dout_new,
const paddle::optional<DenseTensor>& d_ddout,
DenseTensor* d_out_new,
DenseTensor* d_dout,
DenseTensor* d_ddx);
......@@ -118,10 +118,10 @@ void TanhTripleGradKernel(const Context& dev_ctx,
template <typename T, typename Context>
void SinTripleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const DenseTensor& ddx,
const paddle::optional<DenseTensor>& dout,
const paddle::optional<DenseTensor>& ddx,
const DenseTensor& d_dx_new,
const DenseTensor& d_ddout,
const paddle::optional<DenseTensor>& d_ddout,
DenseTensor* d_x_new,
DenseTensor* d_dout,
DenseTensor* d_ddx);
......@@ -129,10 +129,10 @@ void SinTripleGradKernel(const Context& dev_ctx,
template <typename T, typename Context>
void CosTripleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const DenseTensor& ddx,
const paddle::optional<DenseTensor>& dout,
const paddle::optional<DenseTensor>& ddx,
const DenseTensor& d_dx_new,
const DenseTensor& d_ddout,
const paddle::optional<DenseTensor>& d_ddout,
DenseTensor* d_x_new,
DenseTensor* d_dout,
DenseTensor* d_ddx);
......
......@@ -64,7 +64,8 @@ void BatchNormGradKernel(const Context& dev_ctx,
DenseTensor* bias_grad);
template <typename T, typename Context>
void BatchNormDoubleGradKernel(const Context& dev_ctx,
void BatchNormDoubleGradKernel(
const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& scale,
const paddle::optional<DenseTensor>& mean,
......@@ -72,17 +73,17 @@ void BatchNormDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& saved_mean,
const DenseTensor& saved_variance,
const DenseTensor& y_grad,
const DenseTensor& x_grad_grad,
const DenseTensor& scale_grad_grad,
const DenseTensor& bias_grad_grad,
const paddle::optional<DenseTensor>& x_grad_grad,
const paddle::optional<DenseTensor>& scale_grad_grad,
const paddle::optional<DenseTensor>& bias_grad_grad,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu,
DenseTensor* x_grad,
DenseTensor* scale_grad,
DenseTensor* y_grad_grad);
} // namespace phi
......@@ -334,7 +334,8 @@ void BatchNormGradKernel(const Context& dev_ctx,
}
template <typename T, typename Context>
void BatchNormDoubleGradKernel(const Context& ctx,
void BatchNormDoubleGradKernel(
const Context& ctx,
const DenseTensor& x,
const DenseTensor& scale,
const paddle::optional<DenseTensor>& mean,
......@@ -342,9 +343,9 @@ void BatchNormDoubleGradKernel(const Context& ctx,
const DenseTensor& saved_mean,
const DenseTensor& saved_variance,
const DenseTensor& y_grad,
const DenseTensor& x_grad_grad,
const DenseTensor& scale_grad_grad,
const DenseTensor& bias_grad_grad,
const paddle::optional<DenseTensor>& x_grad_grad,
const paddle::optional<DenseTensor>& scale_grad_grad,
const paddle::optional<DenseTensor>& bias_grad_grad,
float momentum,
float epsilon,
const std::string& data_layout_str,
......@@ -369,9 +370,9 @@ void BatchNormDoubleGradKernel(const Context& ctx,
const auto data_layout = phi::StringToDataLayout(data_layout_str);
const auto* ddX = &x_grad_grad;
const auto* ddScale = &scale_grad_grad;
const auto* ddBias = &bias_grad_grad;
const auto* ddX = x_grad_grad.get_ptr();
const auto* ddScale = scale_grad_grad.get_ptr();
const auto* ddBias = bias_grad_grad.get_ptr();
auto* dX = x_grad;
auto* dScale = scale_grad;
......
......@@ -108,6 +108,9 @@ PD_REGISTER_KERNEL(full_like,
int,
int64_t,
bool,
phi::dtype::float16) {
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
......@@ -47,8 +47,8 @@ void MultiplyTripleGradKernel(const Context& dev_ctx,
const DenseTensor& dout,
const paddle::optional<DenseTensor>& ddx,
const paddle::optional<DenseTensor>& ddy,
const DenseTensor& d_dx,
const DenseTensor& d_dy,
const paddle::optional<DenseTensor>& d_dx,
const paddle::optional<DenseTensor>& d_dy,
const paddle::optional<DenseTensor>& d_ddout,
int axis,
DenseTensor* d_x,
......
......@@ -125,15 +125,24 @@ struct SinDoubleGradFunctor : public BaseActivationFunctor<T> {
// calculate d2x first, so d2d1y can inplace d2d1x
auto d2x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "d2x", "SinDoubleGrad"));
if (dX) {
if (dOut) {
auto d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "d1y", "SinDoubleGrad"));
d2x.device(*d) = -d2d1x * x.unaryExpr(Sine<T>()) * d1y;
} else {
d2x.device(*d) = -d2d1x * x.unaryExpr(Sine<T>()) * static_cast<T>(0);
}
}
// calculate d2d1y
if (ddOut) {
auto d2d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "d2d1y", "SinDoubleGrad"));
d2d1y.device(*d) = d2d1x * x.unaryExpr(Cosine<T>());
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
......@@ -167,28 +176,71 @@ struct SinTripleGradFunctor : public BaseActivationFunctor<T> {
auto* d = dev.eigen_device();
auto x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "x", "SinTripleGrad"));
auto d3d2x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_dx_New, "Input", "d3d2x", "SinTripleGrad"));
if (d_x_New) {
auto d3x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_x_New, "Output", "d3x", "SinTripleGrad"));
if (dOut && ddX && d_DDOut) {
auto d2d1x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "d2d1x", "SinTripleGrad"));
auto d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Input", "d1y", "SinTripleGrad"));
auto d3d2d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "SinTripleGrad"));
auto d3d2x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_dx_New, "Input", "d3d2x", "SinTripleGrad"));
auto d3x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_x_New, "Output", "d3x", "SinTripleGrad"));
d3x.device(*d) = -x.unaryExpr(Cosine<T>()) * d1y * d2d1x * d3d2x -
x.unaryExpr(Sine<T>()) * d2d1x * d3d2d1y;
} else if (!dOut && ddX && d_DDOut) {
auto d2d1x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "d2d1x", "SinTripleGrad"));
auto d3d2d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "SinTripleGrad"));
d3x.device(*d) = -x.unaryExpr(Sine<T>()) * d2d1x * d3d2d1y;
} else if (dOut && ddX && !d_DDOut) {
auto d2d1x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "d2d1x", "SinTripleGrad"));
auto d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Input", "d1y", "SinTripleGrad"));
d3x.device(*d) = -x.unaryExpr(Cosine<T>()) * d1y * d2d1x * d3d2x;
} else {
d3x.device(*d) = x * static_cast<T>(0);
}
}
if (d_d_Out) {
auto d3d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_d_Out, "Output", "d3d1y", "SinTripleGrad"));
if (ddX) {
auto d2d1x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "d2d1x", "SinTripleGrad"));
d3d1y.device(*d) = -x.unaryExpr(Sine<T>()) * d2d1x * d3d2x;
} else {
d3d1y.device(*d) = static_cast<T>(0) * x;
}
}
if (d_DDx) {
auto d3d2d1x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_DDx, "Output", "d3d2d1x", "SinTripleGrad"));
if (dOut && d_DDOut) {
auto d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Input", "d1y", "SinTripleGrad"));
auto d3d2d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "SinTripleGrad"));
d3d2d1x.device(*d) = -x.unaryExpr(Sine<T>()) * d1y * d3d2x +
x.unaryExpr(Cosine<T>()) * d3d2d1y;
} else if (dOut && !d_DDOut) {
auto d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Input", "d1y", "SinTripleGrad"));
d3d2d1x.device(*d) = -x.unaryExpr(Sine<T>()) * d1y * d3d2x;
} else if (!dOut && d_DDOut) {
auto d3d2d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "SinTripleGrad"));
d3d2d1x.device(*d) = x.unaryExpr(Cosine<T>()) * d3d2d1y;
} else {
d3d2d1x.device(*d) = x * static_cast<T>(0);
}
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
......@@ -270,15 +322,23 @@ struct CosDoubleGradFunctor : public BaseActivationFunctor<T> {
// calculate d2x first, so d2d1y can inplace d2d1x
auto d2x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "d2x", "CosDoubleGrad"));
if (ddOut) {
if (dOut) {
auto d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "d1y", "CosDoubleGrad"));
d2x.device(*d) = -d2d1x * x.unaryExpr(Cosine<T>()) * d1y;
} else {
d2x.device(*d) = x * static_cast<T>(0);
}
}
if (dX) {
// calculate d2d1y
auto d2d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "d2d1y", "CosDoubleGrad"));
d2d1y.device(*d) = -d2d1x * x.unaryExpr(Sine<T>());
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
......@@ -297,28 +357,72 @@ struct CosTripleGradFunctor : public BaseActivationFunctor<T> {
auto* d = dev.eigen_device();
auto x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "x", "CosTripleGrad"));
auto d3d2x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_dx_New, "Input", "d3d2x", "CosTripleGrad"));
if (d_x_New) {
auto d3x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_x_New, "Output", "d3x", "CosTripleGrad"));
if (dOut && ddX && d_DDOut) {
auto d2d1x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "d2d1x", "CosTripleGrad"));
auto d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Input", "d1y", "CosTripleGrad"));
auto d3d2d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "CosTripleGrad"));
auto d3d2x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_dx_New, "Input", "d3d2x", "CosTripleGrad"));
auto d3x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_x_New, "Output", "d3x", "CosTripleGrad"));
d3x.device(*d) = x.unaryExpr(Sine<T>()) * d1y * d2d1x * d3d2x -
x.unaryExpr(Cosine<T>()) * d2d1x * d3d2d1y;
} else if (dOut && ddX && !d_DDOut) {
auto d2d1x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "d2d1x", "CosTripleGrad"));
auto d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Input", "d1y", "CosTripleGrad"));
d3x.device(*d) = x.unaryExpr(Sine<T>()) * d1y * d2d1x * d3d2x;
} else if (!dOut && ddX && d_DDOut) {
auto d2d1x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "d2d1x", "CosTripleGrad"));
auto d3d2d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "CosTripleGrad"));
d3x.device(*d) = -x.unaryExpr(Cosine<T>()) * d2d1x * d3d2d1y;
} else {
d3x.device(*d) = static_cast<T>(0) * x;
}
}
if (d_d_Out) {
auto d3d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_d_Out, "Output", "d3d1y", "CosTripleGrad"));
if (ddX) {
auto d2d1x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "d2d1x", "CosTripleGrad"));
d3d1y.device(*d) = -x.unaryExpr(Cosine<T>()) * d2d1x * d3d2x;
} else {
d3d1y.device(*d) = static_cast<T>(0) * x;
}
}
if (d_DDx) {
auto d3d2d1x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_DDx, "Output", "d3d2d1x", "CosTripleGrad"));
if (dOut && d_DDOut) {
auto d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Input", "d1y", "CosTripleGrad"));
auto d3d2d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "CosTripleGrad"));
d3d2d1x.device(*d) = -x.unaryExpr(Cosine<T>()) * d1y * d3d2x -
x.unaryExpr(Sine<T>()) * d3d2d1y;
} else if (!dOut && d_DDOut) {
auto d3d2d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "CosTripleGrad"));
d3d2d1x.device(*d) = -x.unaryExpr(Sine<T>()) * d3d2d1y;
} else if (dOut && !d_DDOut) {
auto d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Input", "d1y", "CosTripleGrad"));
d3d2d1x.device(*d) = -x.unaryExpr(Cosine<T>()) * d1y * d3d2x;
} else {
d3d2d1x.device(*d) = static_cast<T>(0) * x;
}
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
......@@ -1106,27 +1210,70 @@ struct TanhTripleGradFunctor : public BaseActivationFunctor<T> {
GET_DATA_SAFELY(Out, "Input", "Out", "TanhTripleGrad"));
auto dout = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Input", "DOut", "TanhTripleGrad"));
auto d_ddOut = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "TanhTripleGrad"));
auto d_dOutNew = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad"));
if (d_Out_New) {
auto d_OutNew = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_Out_New, "Output", "D_OutNew", "TanhTripleGrad"));
if (d_DDOut && d_dOut_New) {
auto d_ddOut = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "TanhTripleGrad"));
auto d_dOutNew = EigenVector<T>::Flatten(GET_DATA_SAFELY(
d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad"));
d_OutNew.device(*d) = (static_cast<T>(-2) * out * ddx * d_ddOut) -
(static_cast<T>(2) * dout * ddx * d_dOutNew);
} else if (d_DDOut && !d_dOut_New) {
auto d_ddOut = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "TanhTripleGrad"));
d_OutNew.device(*d) = (static_cast<T>(-2) * out * ddx * d_ddOut);
} else if (!d_DDOut && d_dOut_New) {
auto d_dOutNew = EigenVector<T>::Flatten(GET_DATA_SAFELY(
d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad"));
d_OutNew.device(*d) = -(static_cast<T>(2) * dout * ddx * d_dOutNew);
} else {
d_OutNew.device(*d) = static_cast<T>(0) * out;
}
}
if (d_d_Out) {
auto d_dOut = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_d_Out, "Output", "D_DOut", "TanhTripleGrad"));
if (d_dOut_New) {
auto d_dOutNew = EigenVector<T>::Flatten(GET_DATA_SAFELY(
d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad"));
d_dOut.device(*d) = static_cast<T>(-2) * out * ddx * d_dOutNew;
} else {
d_dOut.device(*d) = static_cast<T>(0) * out;
}
}
if (d_DDx) {
auto d_ddx = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_DDx, "Output", "D_DDx", "TanhTripleGrad"));
if (d_DDOut && d_dOut_New) {
auto d_ddOut = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "TanhTripleGrad"));
auto d_dOutNew = EigenVector<T>::Flatten(GET_DATA_SAFELY(
d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad"));
d_ddx.device(*d) = (static_cast<T>(1) - (out * out)) * d_ddOut -
static_cast<T>(2) * out * dout * d_dOutNew;
} else if (d_DDOut && !d_dOut_New) {
auto d_ddOut = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "TanhTripleGrad"));
d_ddx.device(*d) = (static_cast<T>(1) - (out * out)) * d_ddOut;
} else if (!d_DDOut && d_dOut_New) {
auto d_dOutNew = EigenVector<T>::Flatten(GET_DATA_SAFELY(
d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad"));
d_ddx.device(*d) = -static_cast<T>(2) * out * dout * d_dOutNew;
} else {
d_ddx.device(*d) = static_cast<T>(0) * ddx;
}
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
......
......@@ -1295,7 +1295,8 @@ void BatchNormGradKernel(const Context &dev_ctx,
}
template <typename T, typename Context>
void BatchNormDoubleGradKernel(const Context &ctx,
void BatchNormDoubleGradKernel(
const Context &ctx,
const DenseTensor &x,
const DenseTensor &scale,
const paddle::optional<DenseTensor> &mean,
......@@ -1303,9 +1304,9 @@ void BatchNormDoubleGradKernel(const Context &ctx,
const DenseTensor &saved_mean,
const DenseTensor &saved_variance,
const DenseTensor &y_grad,
const DenseTensor &x_grad_grad,
const DenseTensor &scale_grad_grad,
const DenseTensor &bias_grad_grad,
const paddle::optional<DenseTensor> &x_grad_grad,
const paddle::optional<DenseTensor> &scale_grad_grad,
const paddle::optional<DenseTensor> &bias_grad_grad,
float momentum,
float epsilon,
const std::string &data_layout_str,
......@@ -1330,7 +1331,8 @@ void BatchNormDoubleGradKernel(const Context &ctx,
running_mean = mean.get_ptr();
running_variance = variance.get_ptr();
}
paddle::operators::NormDoubleGradFunctor<Context, T>(ctx,
paddle::operators::NormDoubleGradFunctor<Context, T>(
ctx,
data_layout,
&x,
&scale,
......@@ -1341,9 +1343,9 @@ void BatchNormDoubleGradKernel(const Context &ctx,
running_variance,
epsilon,
use_global_stats,
&x_grad_grad,
&scale_grad_grad,
&bias_grad_grad,
x_grad_grad.get_ptr(),
scale_grad_grad.get_ptr(),
bias_grad_grad.get_ptr(),
x_grad,
scale_grad,
y_grad_grad);
......
......@@ -138,6 +138,8 @@ PD_REGISTER_KERNEL(full_like,
int64_t,
bool,
phi::dtype::bfloat16,
phi::dtype::float16) {
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
......@@ -177,8 +177,8 @@ void TanhTripleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& dout,
const DenseTensor& ddx,
const DenseTensor& d_dout_new,
const DenseTensor& d_ddout,
const paddle::optional<DenseTensor>& d_dout_new,
const paddle::optional<DenseTensor>& d_ddout,
DenseTensor* d_out_new,
DenseTensor* d_dout,
DenseTensor* d_ddx) {
......@@ -199,8 +199,8 @@ void TanhTripleGradKernel(const Context& dev_ctx,
&out,
&ddx,
&dout,
&d_ddout,
&d_dout_new, // input
d_ddout.get_ptr(),
d_dout_new.get_ptr(), // input
d_dout,
d_out_new,
d_ddx); // output
......@@ -597,49 +597,45 @@ void SquareDoubleGradKernel(const Context& dev_ctx,
template <typename T, typename Context>
void SinDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const paddle::optional<DenseTensor>& dout,
const DenseTensor& ddx,
DenseTensor* dx,
DenseTensor* ddout) {
if (dx) {
dx->Resize(x.dims());
dev_ctx.template Alloc<T>(dx);
}
if (ddout) {
dev_ctx.template Alloc<T>(ddout);
}
phi::funcs::SinDoubleGradFunctor<T> functor;
functor(dev_ctx, &x, &dout, &ddx, dx, ddout);
functor(dev_ctx, &x, dout.get_ptr(), &ddx, dx, ddout);
}
template <typename T, typename Context>
void SinTripleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const DenseTensor& ddx,
const paddle::optional<DenseTensor>& dout,
const paddle::optional<DenseTensor>& ddx,
const DenseTensor& d_dx_new,
const DenseTensor& d_ddout,
const paddle::optional<DenseTensor>& d_ddout,
DenseTensor* d_x_new,
DenseTensor* d_dout,
DenseTensor* d_ddx) {
if (d_dout) {
d_dout->Resize(x.dims());
dev_ctx.template Alloc<T>(d_dout);
}
if (d_x_new) {
d_dout->Resize(x.dims());
dev_ctx.template Alloc<T>(d_x_new);
}
if (d_ddx) {
d_dout->Resize(ddx.dims());
dev_ctx.template Alloc<T>(d_ddx);
}
funcs::SinTripleGradFunctor<T> functor;
functor(dev_ctx,
&x,
&ddx,
&dout,
&d_ddout,
ddx.get_ptr(),
dout.get_ptr(),
d_ddout.get_ptr(),
&d_dx_new, // input
d_dout,
d_x_new,
......@@ -649,49 +645,45 @@ void SinTripleGradKernel(const Context& dev_ctx,
template <typename T, typename Context>
void CosDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const paddle::optional<DenseTensor>& dout,
const DenseTensor& ddx,
DenseTensor* dx,
DenseTensor* ddout) {
if (dx) {
dx->Resize(x.dims());
dev_ctx.template Alloc<T>(dx);
}
if (ddout) {
dev_ctx.template Alloc<T>(ddout);
}
phi::funcs::CosDoubleGradFunctor<T> functor;
functor(dev_ctx, &x, &dout, &ddx, dx, ddout);
functor(dev_ctx, &x, dout.get_ptr(), &ddx, dx, ddout);
}
template <typename T, typename Context>
void CosTripleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const DenseTensor& ddx,
const paddle::optional<DenseTensor>& dout,
const paddle::optional<DenseTensor>& ddx,
const DenseTensor& d_dx_new,
const DenseTensor& d_ddout,
const paddle::optional<DenseTensor>& d_ddout,
DenseTensor* d_x_new,
DenseTensor* d_dout,
DenseTensor* d_ddx) {
if (d_dout) {
d_dout->Resize(x.dims());
dev_ctx.template Alloc<T>(d_dout);
}
if (d_x_new) {
d_dout->Resize(x.dims());
dev_ctx.template Alloc<T>(d_x_new);
}
if (d_ddx) {
d_dout->Resize(ddx.dims());
dev_ctx.template Alloc<T>(d_ddx);
}
funcs::CosTripleGradFunctor<T> functor;
functor(dev_ctx,
&x,
&ddx,
&dout,
&d_ddout,
ddx.get_ptr(),
dout.get_ptr(),
d_ddout.get_ptr(),
&d_dx_new, // input
d_dout,
d_x_new,
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
......@@ -472,6 +473,7 @@ void MultiplyDoubleGradKernel(const Context& dev_ctx,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, y, ddx_safe, ddout, axis);
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
......@@ -483,12 +485,14 @@ void MultiplyDoubleGradKernel(const Context& dev_ctx,
ddout_t.device(place) = ddout_t + ddout_tmp_t;
} else {
// use dx to save memory, other than alloc tmp tensor
if (dx) {
DenseTensor* ddout_tmp = dx;
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, x, ddy_safe, ddout_tmp, axis);
// NOTE: in the following ElemwiseGradCompute, for the
// first output tensor is nullptr, the branch to calculate first
// output tensor will not be activated, DivGradDx function will not
......@@ -505,6 +509,7 @@ void MultiplyDoubleGradKernel(const Context& dev_ctx,
dy,
MulGradDX<T>(),
MulGradDY<T>());
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
......@@ -514,11 +519,36 @@ void MultiplyDoubleGradKernel(const Context& dev_ctx,
auto ddout_t = phi::EigenVector<T>::Flatten(*ddout);
auto ddout_tmp_t = phi::EigenVector<T>::Flatten(*ddout_tmp);
ddout_t.device(place) = ddout_t + ddout_tmp_t;
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, dout, ddy_safe, dx, axis);
} else {
DenseTensor tmp_a(ddout->dtype());
tmp_a.Resize(ddout->dims());
dev_ctx.template Alloc<T>(&tmp_a);
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, x, ddy_safe, &tmp_a, axis);
auto ddout_t1 = phi::EigenVector<T>::Flatten(tmp_a);
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, ddx_safe, y, ddout, axis);
auto ddout_t2 = phi::EigenVector<T>::Flatten(*ddout);
ddout_t2.device(place) = ddout_t2 + ddout_t1;
}
}
} else {
if (dx && dy) {
......@@ -544,8 +574,8 @@ void MultiplyTripleGradKernel(const Context& dev_ctx,
const DenseTensor& dout,
const paddle::optional<DenseTensor>& ddx,
const paddle::optional<DenseTensor>& ddy,
const DenseTensor& d_dx,
const DenseTensor& d_dy,
const paddle::optional<DenseTensor>& d_dx,
const paddle::optional<DenseTensor>& d_dy,
const paddle::optional<DenseTensor>& d_ddout,
int axis,
DenseTensor* d_x,
......@@ -599,6 +629,13 @@ void MultiplyTripleGradKernel(const Context& dev_ctx,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, ddx_safe, *(d_ddout.get_ptr()), d_y, axis);
}
} else {
if (d_x) {
FullLikeKernel<T, Context>(dev_ctx, x, Scalar(0.0), x.dtype(), d_x);
}
if (d_y) {
FullLikeKernel<T, Context>(dev_ctx, y, Scalar(0.0), y.dtype(), d_y);
}
}
if (d_dout) {
......@@ -607,61 +644,135 @@ void MultiplyTripleGradKernel(const Context& dev_ctx,
DenseTensor d_dout_tmp;
d_dout_tmp.Resize(dout.dims());
dev_ctx.template Alloc<T>(&d_dout_tmp);
if (d_dy && d_dx) {
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, d_dy, ddx_safe, d_dout, axis);
dev_ctx, d_dy.get(), ddx_safe, d_dout, axis);
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, ddy_safe, d_dx, &d_dout_tmp, axis);
dev_ctx, ddy_safe, d_dx.get(), &d_dout_tmp, axis);
auto d_dout_t = phi::EigenVector<T>::Flatten(*d_dout);
auto d_dout_tmp_t = phi::EigenVector<T>::Flatten(d_dout_tmp);
d_dout_t.device(place) = d_dout_t + d_dout_tmp_t;
} else if (d_dy && !d_dx) {
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, d_dy.get(), ddx_safe, d_dout, axis);
auto d_dout_t = phi::EigenVector<T>::Flatten(*d_dout);
d_dout_t.device(place) = d_dout_t;
} else if (!d_dy && d_dx) {
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, ddy_safe, d_dx.get(), d_dout, axis);
auto d_dout_t = phi::EigenVector<T>::Flatten(*d_dout);
d_dout_t.device(place) = d_dout_t;
} else {
FullLikeKernel<T, Context>(
dev_ctx, dout, Scalar(0.0), dout.dtype(), d_dout);
}
}
if (d_ddx) {
if (d_ddx && ddx) {
// get d_ddx
// d_ddx = dout * d_dy + y * d_ddout
DenseTensor d_ddx_tmp;
d_ddx_tmp.Resize(ddx->dims());
dev_ctx.template Alloc<T>(&d_ddx_tmp);
if (d_dy && d_ddout) {
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, dout, d_dy, d_ddx, axis);
dev_ctx, dout, d_dy.get(), d_ddx, axis);
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, y, *(d_ddout.get_ptr()), &d_ddx_tmp, axis);
auto d_ddx_t = phi::EigenVector<T>::Flatten(*d_ddx);
auto d_ddx_tmp_t = phi::EigenVector<T>::Flatten(d_ddx_tmp);
d_ddx_t.device(place) = d_ddx_t + d_ddx_tmp_t;
} else if (d_dy && !d_ddout) {
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, dout, d_dy.get(), d_ddx, axis);
auto d_ddx_t = phi::EigenVector<T>::Flatten(*d_ddx);
d_ddx_t.device(place) = d_ddx_t;
} else if (!d_dy && d_ddout) {
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, y, *(d_ddout.get_ptr()), d_ddx, axis);
auto d_ddx_t = phi::EigenVector<T>::Flatten(*d_ddx);
d_ddx_t.device(place) = d_ddx_t;
} else {
FullLikeKernel<T, Context>(dev_ctx, x, Scalar(0.0), x.dtype(), d_ddx);
}
}
if (d_ddy) {
if (d_ddy && ddy) {
// get d_ddy
// d_ddy = dout * d_dx + x * d_ddout
DenseTensor d_ddy_tmp;
d_ddy_tmp.Resize(ddy->dims());
dev_ctx.template Alloc<T>(&d_ddy_tmp);
if (d_dx && d_ddout) {
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, dout, d_dx, d_ddy, axis);
dev_ctx, dout, d_dx.get(), d_ddy, axis);
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, x, *(d_ddout.get_ptr()), &d_ddy_tmp, axis);
auto d_ddy_t = phi::EigenVector<T>::Flatten(*d_ddy);
auto d_ddy_tmp_t = phi::EigenVector<T>::Flatten(d_ddy_tmp);
d_ddy_t.device(place) = d_ddy_t + d_ddy_tmp_t;
} else if (d_dx && !d_ddout) {
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, dout, d_dx.get(), d_ddy, axis);
auto d_ddy_t = phi::EigenVector<T>::Flatten(*d_ddy);
d_ddy_t.device(place) = d_ddy_t;
} else if (!d_dx && d_ddout) {
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, x, *(d_ddout.get_ptr()), d_ddy, axis);
auto d_ddy_t = phi::EigenVector<T>::Flatten(*d_ddy);
d_ddy_t.device(place) = d_ddy_t;
} else {
FullLikeKernel<T, Context>(dev_ctx, y, Scalar(0.0), y.dtype(), d_ddy);
}
}
}
......
......@@ -15,6 +15,7 @@
#pragma once
#include <limits>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
......
......@@ -28,4 +28,4 @@ void LogcumsumexpGradKernel(const Context& dev_ctx,
bool exclusive,
bool reverse,
DenseTensor* d_x);
}
} // namespace phi
......@@ -47,8 +47,8 @@ void MatmulTripleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
const DenseTensor& ddx,
const DenseTensor& ddy,
const paddle::optional<DenseTensor>& ddx,
const paddle::optional<DenseTensor>& ddy,
const paddle::optional<DenseTensor>& d_dx,
const paddle::optional<DenseTensor>& d_dy,
const paddle::optional<DenseTensor>& d_ddout,
......
......@@ -140,6 +140,8 @@ PD_REGISTER_KERNEL(full_like,
float,
int,
int64_t,
phi::dtype::float16) {
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册