diff --git a/paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py b/paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py index 748c9d1ad22f22e25e9854d00fca4790fab9b9da..8d03670a80773bae43e99f122d0e1cd789753640 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py +++ b/paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py @@ -35,6 +35,7 @@ ops_to_fill_zero_for_empty_grads = set( "multiply_triple_grad", "conv2d_grad_grad", "batch_norm_double_grad", + "tanh_grad", "tanh_double_grad", "tanh_triple_grad", "sin_double_grad", diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index e726ec8bd96706ae15cacc34087ba23fb579571e..b54f45363a00daa71b6de6ac87d9657fa0ff2c29 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -230,7 +230,7 @@ FORWARD_FUNCTION_TEMPLATE = """ AFTER_LOG_PRINT_TEMPLATE = """ if(VLOG_IS_ON(4)){{ - const char* INPUT_PRINT_TEMPLATE = \"{{ Input: [%s], Output: [%s] }} \"; + const char* INPUT_PRINT_TEMPLATE = \"{{ Input: [%s], \\n Output: [%s] }} \"; {} VLOG(4) << paddle::string::Sprintf(INPUT_PRINT_TEMPLATE, input_str, output_str); }} diff --git a/paddle/fluid/eager/backward.cc b/paddle/fluid/eager/backward.cc index 20709d13166a7cdfa3a3360c981a9cb1e30dbc76..15c67f451be48d1cd9cad0404b2c5b1d4cbcabfe 100644 --- a/paddle/fluid/eager/backward.cc +++ b/paddle/fluid/eager/backward.cc @@ -173,9 +173,10 @@ std::vector RunBackward( node_input_buffers_dict[grad_node] = std::make_unique(grad_node->InputMeta()); } - bool copy_from_grad_t = - grad_tensors.size() > 0 && grad_tensors[i].initialized(); - if (copy_from_grad_t) { + + // copy grad tensor since we should totally run grad without affect forward + // value + if (grad_tensors.size() > 0 && grad_tensors[i].initialized()) { PADDLE_ENFORCE( grad_tensors.size() == tensors.size(), paddle::platform::errors::Fatal( @@ -357,22 +358,11 @@ std::vector RunBackward( "Node's in-degree cannot be negative.", next_node->name())); - if (is_general_grad) { - if (node_in_degree_map[next_node] == 0 && - GeneralGrad::Instance().IsNeededNodes(next_node)) { - if (dynamic_cast(next_node)) { - queue.push_front(std::move(next_node)); - } else { - queue.push_back(std::move(next_node)); - } - } - } else { - if (node_in_degree_map[next_node] == 0) { - if (dynamic_cast(next_node)) { - queue.push_front(std::move(next_node)); - } else { - queue.push_back(std::move(next_node)); - } + if (node_in_degree_map[next_node] == 0) { + if (dynamic_cast(next_node)) { + queue.push_front(std::move(next_node)); + } else { + queue.push_back(std::move(next_node)); } } } diff --git a/paddle/fluid/eager/general_grad.h b/paddle/fluid/eager/general_grad.h index 27f6a7e609a4ddfa31ed6b2fbcc458628bced4b8..142624a9d95642cb130e0258fc25e072a78967a6 100644 --- a/paddle/fluid/eager/general_grad.h +++ b/paddle/fluid/eager/general_grad.h @@ -51,6 +51,10 @@ class GeneralGrad { for (size_t i = 0; i < num_inputs; i++) { AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(inputs[i]); + PADDLE_ENFORCE_NOT_NULL( + auto_grad_meta, + paddle::platform::errors::Fatal( + "We got %s:[%d] 's autograd meta is NULL.", msg, i)); auto* target_node = auto_grad_meta->GetMutableGradNode().get(); if (orig_to_copied_node_map_.count(target_node)) { @@ -82,10 +86,13 @@ class GeneralGrad { // input_target_nodes void PurifyPotentialStartUpNodes() { VLOG(6) << "Running in PurifyPotentialStartUpNodes"; - if (input_target_nodes_inputmeta_map_.empty()) return; + if (input_target_nodes_inputmeta_map_.empty()) { + VLOG(6) << "No input target nodes found, skip."; + return; + } std::unordered_set potential_startup_nodes_to_be_erased; - for (auto startup_op : potential_startup_nodes_) { - auto iter = input_target_nodes_inputmeta_map_.find(startup_op); + for (auto startup_node : potential_startup_nodes_) { + auto iter = input_target_nodes_inputmeta_map_.find(startup_node); if (iter != input_target_nodes_inputmeta_map_.end()) { potential_startup_nodes_to_be_erased.emplace(iter->first); } @@ -157,11 +164,11 @@ class GeneralGrad { potential_startup_nodes_.erase(node); } } - } + } // TODO(jiabin): May we need some check here. } // Get Graph Info Betweent input target GradNode and outputs, - // record depending_nodes_, potential_startup_nodes_ + // record depending_nodes_ void GetGraphInfoBetweenTargets(const std::deque& init_queue) { VLOG(6) << "Runing In GetGraphInfoBetweenTargets"; @@ -227,7 +234,7 @@ class GeneralGrad { std::make_shared(target_result); } } - } + } // TODO(jiabin): Some check here. } void SetResultForEnddingNodes( @@ -319,21 +326,22 @@ class GeneralGrad { void SetNodeToAccumulationNode(GradNodeBase* node) { if (dynamic_cast(node)) return; if (!(depending_nodes_)[node].empty()) { + // Find precedding_nodes of current node. auto precedding_nodes = (depending_nodes_)[node]; for (auto pre_nodes : precedding_nodes) { paddle::small_vector, kSlotSmallVectorSize>& pre_nodes_edges = pre_nodes->MutableOutputMeta(); for (size_t i = 0; i < pre_nodes_edges.size(); i++) { for (size_t j = 0; j < pre_nodes_edges[i].size(); j++) { - auto edge_ = pre_nodes_edges[i][j].GetEdge(); + const auto& edge_ = pre_nodes_edges[i][j].GetEdge(); if (edge_.GetGradNode() == node) { - auto autograd_meta = egr::AutogradMeta(edge_); Edge& pre_node_edge = pre_nodes_edges[i][j].GetMutableEdge(); if (copied_node_to_endding_node_map_.count(node)) { pre_node_edge.SetGradNode( copied_node_to_endding_node_map_[node]); } else { + auto autograd_meta = egr::AutogradMeta(edge_); std::shared_ptr shared_grad_node_accumulation = std::make_shared(&autograd_meta); pre_node_edge.SetGradNode(shared_grad_node_accumulation); @@ -361,7 +369,7 @@ class GeneralGrad { grad_node->SetGradientHookFuntions( node->GetGradientHookFuntions()); } - } + } // or this node has no need to change } } } @@ -381,11 +389,9 @@ class GeneralGrad { } visited.insert(node); - if (IsInputTargetNodes(node)) { - if (IsEnddingNodes(node)) { - SetNodeToAccumulationNode(node); - continue; - } + if (IsInputTargetNodes(node) && IsEnddingNodes(node)) { + SetNodeToAccumulationNode(node); + continue; } paddle::small_vector, kSlotSmallVectorSize>& @@ -411,7 +417,17 @@ class GeneralGrad { continue; } - // TODO(weilong): support prune logic deeper + if (meta.size() != 1 && IsNeededNodes(node) && + !IsNeededNodes(next_node.get()) && !IsEnddingNodes(node)) { + VLOG(3) << "Get stop edge from grad_node: " << node->name() << " : " + << node << " to:" << next_node->name() << ", " + << next_node.get() << " with output rank info: " << i + << ", " << j; + // No need to compute grad from needed Nodes to no need Nodes + meta[i][j].SetStopGradient(true); + edge.Clear(); + continue; + } // Update BFS queue queue_.push_back(next_node.get()); @@ -502,7 +518,8 @@ class GeneralGrad { // Save node and update mapping orig_to_copied_node_map_[orig_node.get()] = copied_node; copied_grad_nodes_.push_back(copied_node); - + VLOG(3) << "Copied Node: " << orig_node->name() << " ptr: " << orig_node + << " to ptr: " << copied_node; return copied_node.get(); } diff --git a/paddle/fluid/eager/grad_tensor_holder.cc b/paddle/fluid/eager/grad_tensor_holder.cc index 14a8c26f9dcb8d2e2e06620f6b9fec14980cc226..56268924b50f3ceb29bcb8d7394f951976b60d4f 100644 --- a/paddle/fluid/eager/grad_tensor_holder.cc +++ b/paddle/fluid/eager/grad_tensor_holder.cc @@ -99,6 +99,11 @@ void GradTensorHolder::add(size_t slot_id, size_t rank, const paddle::experimental::Tensor& t, bool create_graph) { + if (!t.initialized()) { + VLOG(3) << "No need to do accumulate for uninitialized t."; + return; + } // TODO(jiabin): Remove this when we fix all kernel. + PADDLE_ENFORCE(slot_id < buffer_.size(), paddle::platform::errors::Fatal( "Invalid slot_id for GradTensorHolder::add() " diff --git a/paddle/fluid/eager/utils.h b/paddle/fluid/eager/utils.h index 7146261164900c279280b02dd127abc00c49dfd5..339f7af80364b284455ff7a5654c430b875e6ca3 100644 --- a/paddle/fluid/eager/utils.h +++ b/paddle/fluid/eager/utils.h @@ -277,7 +277,58 @@ class EagerUtils { } else { tensor_info_str += "Unknown"; } - if (VLOG_IS_ON(6)) { + if (VLOG_IS_ON(11)) { + const char* TENSOR_PRINT_TEMPLATE = + "{Name: %s, Initialized: %d, Ptr: %d " + "TensorInfo: [ %s ], Value:[ %s ], ADInfo:[ %s ]}"; + auto* ad_meta = nullable_autograd_meta(t); + if (ad_meta && (ad_meta->WeakGrad().lock().get())) { + std::string ad_info_str = ""; + const char* AD_INFO_TEMPLATE = + "Grad: [ %s ], GradNode: [ %s ], StopGradient: [ %d ]"; + ad_info_str += paddle::string::Sprintf(AD_INFO_TEMPLATE, + TensorStr(ad_meta->Grad()), + GradNodeStr(t), + ad_meta->StopGradient()); + auto* data_ptr = dynamic_cast(t.impl().get()); + if (t.is_initialized() && data_ptr) { + return paddle::string::Sprintf(TENSOR_PRINT_TEMPLATE, + tensor_name_str, + t.initialized(), + t.impl(), + tensor_info_str, + *data_ptr, + ad_info_str); + } else { + return paddle::string::Sprintf(TENSOR_PRINT_TEMPLATE, + tensor_name_str, + t.initialized(), + t.impl(), + tensor_info_str, + "None", + ad_info_str); + } + } else { + auto* data_ptr = dynamic_cast(t.impl().get()); + if (t.is_initialized() && data_ptr) { + return paddle::string::Sprintf(TENSOR_PRINT_TEMPLATE, + tensor_name_str, + t.initialized(), + t.impl(), + tensor_info_str, + *data_ptr, + "None"); + } else { + return paddle::string::Sprintf(TENSOR_PRINT_TEMPLATE, + tensor_name_str, + t.initialized(), + t.impl(), + tensor_info_str, + "None", + "None"); + } + } + } else if (VLOG_IS_ON(6)) { const char* TENSOR_PRINT_TEMPLATE = "{Name: %s, Initialized: %d, Ptr: %d " "TensorInfo: [ %s ], ADInfo:[ %s ]}"; diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 44afc43c046d70fd189064bc699b9853d11e7750..2d333805b5aa02edd43c2bcadf059886d8c27783 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -187,6 +187,7 @@ param : [x, x] kernel : func : cos_double_grad + optional: grad_out backward : cos_triple_grad inplace : (grad_x_grad -> grad_out_grad) @@ -211,6 +212,7 @@ param : [x, x, grad_x_grad_forward] kernel : func : cos_triple_grad + optional: grad_out_forward, grad_x_grad_forward, grad_out_grad_grad inplace : (grad_x_grad_forward -> grad_out_forward_grad) - backward_op : cosh_grad @@ -872,6 +874,7 @@ param : [x, x] kernel : func : sin_double_grad + optional: grad_out backward : sin_triple_grad inplace : (grad_x_grad -> grad_out_grad) @@ -896,6 +899,7 @@ param : [x, x, grad_x_grad_forward] kernel : func : sin_triple_grad + optional: grad_out_forward, grad_x_grad_forward, grad_out_grad_grad inplace : (grad_x_grad_forward -> grad_out_forward_grad) - backward_op : sinh_grad @@ -1054,6 +1058,7 @@ kernel : func : tanh_triple_grad inplace : (grad_x_grad_forward -> grad_out_forward_grad) + optional : grad_out_new_grad, grad_out_grad_grad - backward_op : thresholded_relu_grad forward : thresholded_relu (Tensor x, float threshold) -> Tensor(out) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 064c6b00a88494e02c7f03b7f6374a014a695ecd..b0ce57461685ef137aceb0611d0302ecab240bc6 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -124,7 +124,7 @@ kernel : func : batch_norm_grad_grad data_type : x - optional : out_mean, out_variance + optional : out_mean, out_variance, grad_x_grad, grad_scale_grad, grad_bias_grad inplace : (grad_out -> grad_out_grad) - backward_op : batch_norm_grad @@ -856,7 +856,7 @@ param : [x, y, fwd_grad_out, fwd_grad_grad_x, fwd_grad_grad_y] kernel : func : matmul_triple_grad - optional : grad_x_grad, grad_y_grad, grad_grad_out_grad + optional : fwd_grad_grad_x, fwd_grad_grad_y, grad_x_grad, grad_y_grad, grad_grad_out_grad - backward_op : max_grad forward: max (Tensor x, IntArray axis={}, bool keepdim=false) -> Tensor(out) @@ -1024,10 +1024,10 @@ output : Tensor(x_grad), Tensor(y_grad), Tensor(fwd_grad_out_grad), Tensor(fwd_grad_grad_x_grad), Tensor(fwd_grad_grad_y_grad) infer_meta : func : GeneralQuinaryGradInferMeta - param : [x, y, fwd_grad_out, x, y] + param : [x, y, fwd_grad_out, fwd_grad_grad_x, fwd_grad_grad_y] kernel : func : multiply_triple_grad - optional : fwd_grad_grad_x, fwd_grad_grad_y, grad_grad_out_grad + optional : fwd_grad_grad_x, fwd_grad_grad_y, grad_x_grad, grad_y_grad, grad_grad_out_grad - backward_op : nearest_interp_grad forward : nearest_interp (Tensor x, Tensor out_size, Tensor[] size_tensor, Tensor scale_tensor, str data_layout, int out_d, int out_h, int out_w, float[] scale, str interp_method, bool align_corners, int align_mode) -> Tensor(output) diff --git a/paddle/phi/kernels/activation_grad_kernel.h b/paddle/phi/kernels/activation_grad_kernel.h index 847383fc38e94260b9daba84b35b2490820e95d8..56cb316640d1923eaf72282b5bb503431f71d051 100644 --- a/paddle/phi/kernels/activation_grad_kernel.h +++ b/paddle/phi/kernels/activation_grad_kernel.h @@ -83,7 +83,7 @@ void ReluDoubleGradKernel(const Context& dev_ctx, template void SinDoubleGradKernel(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& dout, + const paddle::optional& dout, const DenseTensor& ddx, DenseTensor* dx, DenseTensor* ddout); @@ -91,7 +91,7 @@ void SinDoubleGradKernel(const Context& dev_ctx, template void CosDoubleGradKernel(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& dout, + const paddle::optional& dout, const DenseTensor& ddx, DenseTensor* dx, DenseTensor* ddout); @@ -109,8 +109,8 @@ void TanhTripleGradKernel(const Context& dev_ctx, const DenseTensor& out, const DenseTensor& dout, const DenseTensor& ddx, - const DenseTensor& d_dout_new, - const DenseTensor& d_ddout, + const paddle::optional& d_dout_new, + const paddle::optional& d_ddout, DenseTensor* d_out_new, DenseTensor* d_dout, DenseTensor* d_ddx); @@ -118,10 +118,10 @@ void TanhTripleGradKernel(const Context& dev_ctx, template void SinTripleGradKernel(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& dout, - const DenseTensor& ddx, + const paddle::optional& dout, + const paddle::optional& ddx, const DenseTensor& d_dx_new, - const DenseTensor& d_ddout, + const paddle::optional& d_ddout, DenseTensor* d_x_new, DenseTensor* d_dout, DenseTensor* d_ddx); @@ -129,10 +129,10 @@ void SinTripleGradKernel(const Context& dev_ctx, template void CosTripleGradKernel(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& dout, - const DenseTensor& ddx, + const paddle::optional& dout, + const paddle::optional& ddx, const DenseTensor& d_dx_new, - const DenseTensor& d_ddout, + const paddle::optional& d_ddout, DenseTensor* d_x_new, DenseTensor* d_dout, DenseTensor* d_ddx); diff --git a/paddle/phi/kernels/batch_norm_grad_kernel.h b/paddle/phi/kernels/batch_norm_grad_kernel.h index 24e23e8d690746d8c9f3e6b9a1bd8276171ffef6..2ef183559099f1082bdb0390f5cb1b282e87dbce 100644 --- a/paddle/phi/kernels/batch_norm_grad_kernel.h +++ b/paddle/phi/kernels/batch_norm_grad_kernel.h @@ -64,25 +64,26 @@ void BatchNormGradKernel(const Context& dev_ctx, DenseTensor* bias_grad); template -void BatchNormDoubleGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& scale, - const paddle::optional& mean, - const paddle::optional& variance, - const DenseTensor& saved_mean, - const DenseTensor& saved_variance, - const DenseTensor& y_grad, - const DenseTensor& x_grad_grad, - const DenseTensor& scale_grad_grad, - const DenseTensor& bias_grad_grad, - float momentum, - float epsilon, - const std::string& data_layout, - bool is_test, - bool use_global_stats, - bool trainable_statistics, - DenseTensor* x_grad, - DenseTensor* scale_grad, - DenseTensor* y_grad_grad); - +void BatchNormDoubleGradKernel( + const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& scale, + const paddle::optional& mean, + const paddle::optional& variance, + const DenseTensor& saved_mean, + const DenseTensor& saved_variance, + const DenseTensor& y_grad, + const paddle::optional& x_grad_grad, + const paddle::optional& scale_grad_grad, + const paddle::optional& bias_grad_grad, + float momentum, + float epsilon, + const std::string& data_layout, + bool is_test, + bool use_global_stats, + bool trainable_statistics, + bool fuse_with_relu, + DenseTensor* x_grad, + DenseTensor* scale_grad, + DenseTensor* y_grad_grad); } // namespace phi diff --git a/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc index 8d0ae7e08d70bd46b0c8cef833bc1e0d11098514..49555410f99201ebec6adf8b0708c8f0ab4f8b9f 100644 --- a/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc @@ -334,26 +334,27 @@ void BatchNormGradKernel(const Context& dev_ctx, } template -void BatchNormDoubleGradKernel(const Context& ctx, - const DenseTensor& x, - const DenseTensor& scale, - const paddle::optional& mean, - const paddle::optional& variance, - const DenseTensor& saved_mean, - const DenseTensor& saved_variance, - const DenseTensor& y_grad, - const DenseTensor& x_grad_grad, - const DenseTensor& scale_grad_grad, - const DenseTensor& bias_grad_grad, - float momentum, - float epsilon, - const std::string& data_layout_str, - bool is_test, - bool use_global_stats, - bool trainable_statistics, - DenseTensor* x_grad, - DenseTensor* scale_grad, - DenseTensor* y_grad_grad) { +void BatchNormDoubleGradKernel( + const Context& ctx, + const DenseTensor& x, + const DenseTensor& scale, + const paddle::optional& mean, + const paddle::optional& variance, + const DenseTensor& saved_mean, + const DenseTensor& saved_variance, + const DenseTensor& y_grad, + const paddle::optional& x_grad_grad, + const paddle::optional& scale_grad_grad, + const paddle::optional& bias_grad_grad, + float momentum, + float epsilon, + const std::string& data_layout_str, + bool is_test, + bool use_global_stats, + bool trainable_statistics, + DenseTensor* x_grad, + DenseTensor* scale_grad, + DenseTensor* y_grad_grad) { const auto* X = &x; const auto* Scale = &scale; const auto* dY = &y_grad; @@ -369,9 +370,9 @@ void BatchNormDoubleGradKernel(const Context& ctx, const auto data_layout = phi::StringToDataLayout(data_layout_str); - const auto* ddX = &x_grad_grad; - const auto* ddScale = &scale_grad_grad; - const auto* ddBias = &bias_grad_grad; + const auto* ddX = x_grad_grad.get_ptr(); + const auto* ddScale = scale_grad_grad.get_ptr(); + const auto* ddBias = bias_grad_grad.get_ptr(); auto* dX = x_grad; auto* dScale = scale_grad; diff --git a/paddle/phi/kernels/cpu/full_kernel.cc b/paddle/phi/kernels/cpu/full_kernel.cc index 6571cb2ca8faa91203a0ea33fdcf0e651a735bd2..e7dd6249f3644c53db9a6b394020a6b9c3438d6f 100644 --- a/paddle/phi/kernels/cpu/full_kernel.cc +++ b/paddle/phi/kernels/cpu/full_kernel.cc @@ -108,6 +108,9 @@ PD_REGISTER_KERNEL(full_like, int, int64_t, bool, - phi::dtype::float16) { + phi::dtype::float16, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) { kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); } diff --git a/paddle/phi/kernels/elementwise_multiply_grad_kernel.h b/paddle/phi/kernels/elementwise_multiply_grad_kernel.h index 9cbd5040666cf825a3b65c2bdf64291bbf522841..f175416054086d87c8273235274665ce4b90cf3d 100644 --- a/paddle/phi/kernels/elementwise_multiply_grad_kernel.h +++ b/paddle/phi/kernels/elementwise_multiply_grad_kernel.h @@ -47,8 +47,8 @@ void MultiplyTripleGradKernel(const Context& dev_ctx, const DenseTensor& dout, const paddle::optional& ddx, const paddle::optional& ddy, - const DenseTensor& d_dx, - const DenseTensor& d_dy, + const paddle::optional& d_dx, + const paddle::optional& d_dy, const paddle::optional& d_ddout, int axis, DenseTensor* d_x, diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h index ccdff93d5b23c7932ef2255f8fc188be87ed3e49..35970e2b7df91429b1c0f1e9ce8464c54b0d99cf 100644 --- a/paddle/phi/kernels/funcs/activation_functor.h +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -125,14 +125,23 @@ struct SinDoubleGradFunctor : public BaseActivationFunctor { // calculate d2x first, so d2d1y can inplace d2d1x auto d2x = EigenVector::Flatten( GET_DATA_SAFELY(dX, "Output", "d2x", "SinDoubleGrad")); - auto d1y = EigenVector::Flatten( - GET_DATA_SAFELY(dOut, "Output", "d1y", "SinDoubleGrad")); - d2x.device(*d) = -d2d1x * x.unaryExpr(Sine()) * d1y; + + if (dX) { + if (dOut) { + auto d1y = EigenVector::Flatten( + GET_DATA_SAFELY(dOut, "Output", "d1y", "SinDoubleGrad")); + d2x.device(*d) = -d2d1x * x.unaryExpr(Sine()) * d1y; + } else { + d2x.device(*d) = -d2d1x * x.unaryExpr(Sine()) * static_cast(0); + } + } // calculate d2d1y - auto d2d1y = EigenVector::Flatten( - GET_DATA_SAFELY(ddOut, "Output", "d2d1y", "SinDoubleGrad")); - d2d1y.device(*d) = d2d1x * x.unaryExpr(Cosine()); + if (ddOut) { + auto d2d1y = EigenVector::Flatten( + GET_DATA_SAFELY(ddOut, "Output", "d2d1y", "SinDoubleGrad")); + d2d1y.device(*d) = d2d1x * x.unaryExpr(Cosine()); + } } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; @@ -167,28 +176,71 @@ struct SinTripleGradFunctor : public BaseActivationFunctor { auto* d = dev.eigen_device(); auto x = EigenVector::Flatten( GET_DATA_SAFELY(X, "Input", "x", "SinTripleGrad")); - auto d2d1x = EigenVector::Flatten( - GET_DATA_SAFELY(ddX, "Input", "d2d1x", "SinTripleGrad")); - auto d1y = EigenVector::Flatten( - GET_DATA_SAFELY(dOut, "Input", "d1y", "SinTripleGrad")); - auto d3d2d1y = EigenVector::Flatten( - GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "SinTripleGrad")); auto d3d2x = EigenVector::Flatten( GET_DATA_SAFELY(d_dx_New, "Input", "d3d2x", "SinTripleGrad")); + if (d_x_New) { + auto d3x = EigenVector::Flatten( + GET_DATA_SAFELY(d_x_New, "Output", "d3x", "SinTripleGrad")); + if (dOut && ddX && d_DDOut) { + auto d2d1x = EigenVector::Flatten( + GET_DATA_SAFELY(ddX, "Input", "d2d1x", "SinTripleGrad")); + auto d1y = EigenVector::Flatten( + GET_DATA_SAFELY(dOut, "Input", "d1y", "SinTripleGrad")); + auto d3d2d1y = EigenVector::Flatten( + GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "SinTripleGrad")); + d3x.device(*d) = -x.unaryExpr(Cosine()) * d1y * d2d1x * d3d2x - + x.unaryExpr(Sine()) * d2d1x * d3d2d1y; + } else if (!dOut && ddX && d_DDOut) { + auto d2d1x = EigenVector::Flatten( + GET_DATA_SAFELY(ddX, "Input", "d2d1x", "SinTripleGrad")); + auto d3d2d1y = EigenVector::Flatten( + GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "SinTripleGrad")); + d3x.device(*d) = -x.unaryExpr(Sine()) * d2d1x * d3d2d1y; + } else if (dOut && ddX && !d_DDOut) { + auto d2d1x = EigenVector::Flatten( + GET_DATA_SAFELY(ddX, "Input", "d2d1x", "SinTripleGrad")); + auto d1y = EigenVector::Flatten( + GET_DATA_SAFELY(dOut, "Input", "d1y", "SinTripleGrad")); + d3x.device(*d) = -x.unaryExpr(Cosine()) * d1y * d2d1x * d3d2x; + } else { + d3x.device(*d) = x * static_cast(0); + } + } - auto d3x = EigenVector::Flatten( - GET_DATA_SAFELY(d_x_New, "Output", "d3x", "SinTripleGrad")); - d3x.device(*d) = -x.unaryExpr(Cosine()) * d1y * d2d1x * d3d2x - - x.unaryExpr(Sine()) * d2d1x * d3d2d1y; - - auto d3d1y = EigenVector::Flatten( - GET_DATA_SAFELY(d_d_Out, "Output", "d3d1y", "SinTripleGrad")); - d3d1y.device(*d) = -x.unaryExpr(Sine()) * d2d1x * d3d2x; + if (d_d_Out) { + auto d3d1y = EigenVector::Flatten( + GET_DATA_SAFELY(d_d_Out, "Output", "d3d1y", "SinTripleGrad")); + if (ddX) { + auto d2d1x = EigenVector::Flatten( + GET_DATA_SAFELY(ddX, "Input", "d2d1x", "SinTripleGrad")); + d3d1y.device(*d) = -x.unaryExpr(Sine()) * d2d1x * d3d2x; + } else { + d3d1y.device(*d) = static_cast(0) * x; + } + } - auto d3d2d1x = EigenVector::Flatten( - GET_DATA_SAFELY(d_DDx, "Output", "d3d2d1x", "SinTripleGrad")); - d3d2d1x.device(*d) = -x.unaryExpr(Sine()) * d1y * d3d2x + - x.unaryExpr(Cosine()) * d3d2d1y; + if (d_DDx) { + auto d3d2d1x = EigenVector::Flatten( + GET_DATA_SAFELY(d_DDx, "Output", "d3d2d1x", "SinTripleGrad")); + if (dOut && d_DDOut) { + auto d1y = EigenVector::Flatten( + GET_DATA_SAFELY(dOut, "Input", "d1y", "SinTripleGrad")); + auto d3d2d1y = EigenVector::Flatten( + GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "SinTripleGrad")); + d3d2d1x.device(*d) = -x.unaryExpr(Sine()) * d1y * d3d2x + + x.unaryExpr(Cosine()) * d3d2d1y; + } else if (dOut && !d_DDOut) { + auto d1y = EigenVector::Flatten( + GET_DATA_SAFELY(dOut, "Input", "d1y", "SinTripleGrad")); + d3d2d1x.device(*d) = -x.unaryExpr(Sine()) * d1y * d3d2x; + } else if (!dOut && d_DDOut) { + auto d3d2d1y = EigenVector::Flatten( + GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "SinTripleGrad")); + d3d2d1x.device(*d) = x.unaryExpr(Cosine()) * d3d2d1y; + } else { + d3d2d1x.device(*d) = x * static_cast(0); + } + } } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; @@ -270,14 +322,22 @@ struct CosDoubleGradFunctor : public BaseActivationFunctor { // calculate d2x first, so d2d1y can inplace d2d1x auto d2x = EigenVector::Flatten( GET_DATA_SAFELY(dX, "Output", "d2x", "CosDoubleGrad")); - auto d1y = EigenVector::Flatten( - GET_DATA_SAFELY(dOut, "Output", "d1y", "CosDoubleGrad")); - d2x.device(*d) = -d2d1x * x.unaryExpr(Cosine()) * d1y; + if (ddOut) { + if (dOut) { + auto d1y = EigenVector::Flatten( + GET_DATA_SAFELY(dOut, "Output", "d1y", "CosDoubleGrad")); + d2x.device(*d) = -d2d1x * x.unaryExpr(Cosine()) * d1y; + } else { + d2x.device(*d) = x * static_cast(0); + } + } - // calculate d2d1y - auto d2d1y = EigenVector::Flatten( - GET_DATA_SAFELY(ddOut, "Output", "d2d1y", "CosDoubleGrad")); - d2d1y.device(*d) = -d2d1x * x.unaryExpr(Sine()); + if (dX) { + // calculate d2d1y + auto d2d1y = EigenVector::Flatten( + GET_DATA_SAFELY(ddOut, "Output", "d2d1y", "CosDoubleGrad")); + d2d1y.device(*d) = -d2d1x * x.unaryExpr(Sine()); + } } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; @@ -297,28 +357,72 @@ struct CosTripleGradFunctor : public BaseActivationFunctor { auto* d = dev.eigen_device(); auto x = EigenVector::Flatten( GET_DATA_SAFELY(X, "Input", "x", "CosTripleGrad")); - auto d2d1x = EigenVector::Flatten( - GET_DATA_SAFELY(ddX, "Input", "d2d1x", "CosTripleGrad")); - auto d1y = EigenVector::Flatten( - GET_DATA_SAFELY(dOut, "Input", "d1y", "CosTripleGrad")); - auto d3d2d1y = EigenVector::Flatten( - GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "CosTripleGrad")); auto d3d2x = EigenVector::Flatten( GET_DATA_SAFELY(d_dx_New, "Input", "d3d2x", "CosTripleGrad")); - auto d3x = EigenVector::Flatten( - GET_DATA_SAFELY(d_x_New, "Output", "d3x", "CosTripleGrad")); - d3x.device(*d) = x.unaryExpr(Sine()) * d1y * d2d1x * d3d2x - - x.unaryExpr(Cosine()) * d2d1x * d3d2d1y; + if (d_x_New) { + auto d3x = EigenVector::Flatten( + GET_DATA_SAFELY(d_x_New, "Output", "d3x", "CosTripleGrad")); + if (dOut && ddX && d_DDOut) { + auto d2d1x = EigenVector::Flatten( + GET_DATA_SAFELY(ddX, "Input", "d2d1x", "CosTripleGrad")); + auto d1y = EigenVector::Flatten( + GET_DATA_SAFELY(dOut, "Input", "d1y", "CosTripleGrad")); + auto d3d2d1y = EigenVector::Flatten( + GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "CosTripleGrad")); + d3x.device(*d) = x.unaryExpr(Sine()) * d1y * d2d1x * d3d2x - + x.unaryExpr(Cosine()) * d2d1x * d3d2d1y; + } else if (dOut && ddX && !d_DDOut) { + auto d2d1x = EigenVector::Flatten( + GET_DATA_SAFELY(ddX, "Input", "d2d1x", "CosTripleGrad")); + auto d1y = EigenVector::Flatten( + GET_DATA_SAFELY(dOut, "Input", "d1y", "CosTripleGrad")); + d3x.device(*d) = x.unaryExpr(Sine()) * d1y * d2d1x * d3d2x; + } else if (!dOut && ddX && d_DDOut) { + auto d2d1x = EigenVector::Flatten( + GET_DATA_SAFELY(ddX, "Input", "d2d1x", "CosTripleGrad")); + auto d3d2d1y = EigenVector::Flatten( + GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "CosTripleGrad")); + d3x.device(*d) = -x.unaryExpr(Cosine()) * d2d1x * d3d2d1y; + } else { + d3x.device(*d) = static_cast(0) * x; + } + } - auto d3d1y = EigenVector::Flatten( - GET_DATA_SAFELY(d_d_Out, "Output", "d3d1y", "CosTripleGrad")); - d3d1y.device(*d) = -x.unaryExpr(Cosine()) * d2d1x * d3d2x; + if (d_d_Out) { + auto d3d1y = EigenVector::Flatten( + GET_DATA_SAFELY(d_d_Out, "Output", "d3d1y", "CosTripleGrad")); + if (ddX) { + auto d2d1x = EigenVector::Flatten( + GET_DATA_SAFELY(ddX, "Input", "d2d1x", "CosTripleGrad")); + d3d1y.device(*d) = -x.unaryExpr(Cosine()) * d2d1x * d3d2x; + } else { + d3d1y.device(*d) = static_cast(0) * x; + } + } - auto d3d2d1x = EigenVector::Flatten( - GET_DATA_SAFELY(d_DDx, "Output", "d3d2d1x", "CosTripleGrad")); - d3d2d1x.device(*d) = -x.unaryExpr(Cosine()) * d1y * d3d2x - - x.unaryExpr(Sine()) * d3d2d1y; + if (d_DDx) { + auto d3d2d1x = EigenVector::Flatten( + GET_DATA_SAFELY(d_DDx, "Output", "d3d2d1x", "CosTripleGrad")); + if (dOut && d_DDOut) { + auto d1y = EigenVector::Flatten( + GET_DATA_SAFELY(dOut, "Input", "d1y", "CosTripleGrad")); + auto d3d2d1y = EigenVector::Flatten( + GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "CosTripleGrad")); + d3d2d1x.device(*d) = -x.unaryExpr(Cosine()) * d1y * d3d2x - + x.unaryExpr(Sine()) * d3d2d1y; + } else if (!dOut && d_DDOut) { + auto d3d2d1y = EigenVector::Flatten( + GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "CosTripleGrad")); + d3d2d1x.device(*d) = -x.unaryExpr(Sine()) * d3d2d1y; + } else if (dOut && !d_DDOut) { + auto d1y = EigenVector::Flatten( + GET_DATA_SAFELY(dOut, "Input", "d1y", "CosTripleGrad")); + d3d2d1x.device(*d) = -x.unaryExpr(Cosine()) * d1y * d3d2x; + } else { + d3d2d1x.device(*d) = static_cast(0) * x; + } + } } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; @@ -1106,27 +1210,70 @@ struct TanhTripleGradFunctor : public BaseActivationFunctor { GET_DATA_SAFELY(Out, "Input", "Out", "TanhTripleGrad")); auto dout = EigenVector::Flatten( GET_DATA_SAFELY(dOut, "Input", "DOut", "TanhTripleGrad")); - auto d_ddOut = EigenVector::Flatten( - GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "TanhTripleGrad")); - auto d_dOutNew = EigenVector::Flatten( - GET_DATA_SAFELY(d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad")); if (d_Out_New) { auto d_OutNew = EigenVector::Flatten( GET_DATA_SAFELY(d_Out_New, "Output", "D_OutNew", "TanhTripleGrad")); - d_OutNew.device(*d) = (static_cast(-2) * out * ddx * d_ddOut) - - (static_cast(2) * dout * ddx * d_dOutNew); + + if (d_DDOut && d_dOut_New) { + auto d_ddOut = EigenVector::Flatten( + GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "TanhTripleGrad")); + auto d_dOutNew = EigenVector::Flatten(GET_DATA_SAFELY( + d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad")); + + d_OutNew.device(*d) = (static_cast(-2) * out * ddx * d_ddOut) - + (static_cast(2) * dout * ddx * d_dOutNew); + + } else if (d_DDOut && !d_dOut_New) { + auto d_ddOut = EigenVector::Flatten( + GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "TanhTripleGrad")); + + d_OutNew.device(*d) = (static_cast(-2) * out * ddx * d_ddOut); + + } else if (!d_DDOut && d_dOut_New) { + auto d_dOutNew = EigenVector::Flatten(GET_DATA_SAFELY( + d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad")); + + d_OutNew.device(*d) = -(static_cast(2) * dout * ddx * d_dOutNew); + } else { + d_OutNew.device(*d) = static_cast(0) * out; + } } if (d_d_Out) { auto d_dOut = EigenVector::Flatten( GET_DATA_SAFELY(d_d_Out, "Output", "D_DOut", "TanhTripleGrad")); - d_dOut.device(*d) = static_cast(-2) * out * ddx * d_dOutNew; + + if (d_dOut_New) { + auto d_dOutNew = EigenVector::Flatten(GET_DATA_SAFELY( + d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad")); + d_dOut.device(*d) = static_cast(-2) * out * ddx * d_dOutNew; + } else { + d_dOut.device(*d) = static_cast(0) * out; + } } if (d_DDx) { auto d_ddx = EigenVector::Flatten( GET_DATA_SAFELY(d_DDx, "Output", "D_DDx", "TanhTripleGrad")); - d_ddx.device(*d) = (static_cast(1) - (out * out)) * d_ddOut - - static_cast(2) * out * dout * d_dOutNew; + + if (d_DDOut && d_dOut_New) { + auto d_ddOut = EigenVector::Flatten( + GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "TanhTripleGrad")); + auto d_dOutNew = EigenVector::Flatten(GET_DATA_SAFELY( + d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad")); + d_ddx.device(*d) = (static_cast(1) - (out * out)) * d_ddOut - + static_cast(2) * out * dout * d_dOutNew; + + } else if (d_DDOut && !d_dOut_New) { + auto d_ddOut = EigenVector::Flatten( + GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "TanhTripleGrad")); + d_ddx.device(*d) = (static_cast(1) - (out * out)) * d_ddOut; + } else if (!d_DDOut && d_dOut_New) { + auto d_dOutNew = EigenVector::Flatten(GET_DATA_SAFELY( + d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad")); + d_ddx.device(*d) = -static_cast(2) * out * dout * d_dOutNew; + } else { + d_ddx.device(*d) = static_cast(0) * ddx; + } } } static constexpr ActBwdOpFwdDeps FwdDeps() { diff --git a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu index 0f4f39629e879313c74d4e9ca80b67769f4e95e1..fd6e92b2ffe06df67056874c545b9264b76a743d 100644 --- a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu @@ -1295,26 +1295,27 @@ void BatchNormGradKernel(const Context &dev_ctx, } template -void BatchNormDoubleGradKernel(const Context &ctx, - const DenseTensor &x, - const DenseTensor &scale, - const paddle::optional &mean, - const paddle::optional &variance, - const DenseTensor &saved_mean, - const DenseTensor &saved_variance, - const DenseTensor &y_grad, - const DenseTensor &x_grad_grad, - const DenseTensor &scale_grad_grad, - const DenseTensor &bias_grad_grad, - float momentum, - float epsilon, - const std::string &data_layout_str, - bool is_test, - bool use_global_stats, - bool trainable_statistics, - DenseTensor *x_grad, - DenseTensor *scale_grad, - DenseTensor *y_grad_grad) { +void BatchNormDoubleGradKernel( + const Context &ctx, + const DenseTensor &x, + const DenseTensor &scale, + const paddle::optional &mean, + const paddle::optional &variance, + const DenseTensor &saved_mean, + const DenseTensor &saved_variance, + const DenseTensor &y_grad, + const paddle::optional &x_grad_grad, + const paddle::optional &scale_grad_grad, + const paddle::optional &bias_grad_grad, + float momentum, + float epsilon, + const std::string &data_layout_str, + bool is_test, + bool use_global_stats, + bool trainable_statistics, + DenseTensor *x_grad, + DenseTensor *scale_grad, + DenseTensor *y_grad_grad) { PADDLE_ENFORCE_EQ(is_test, false, phi::errors::InvalidArgument( @@ -1330,23 +1331,24 @@ void BatchNormDoubleGradKernel(const Context &ctx, running_mean = mean.get_ptr(); running_variance = variance.get_ptr(); } - paddle::operators::NormDoubleGradFunctor(ctx, - data_layout, - &x, - &scale, - &y_grad, - &saved_mean, - &saved_variance, - running_mean, - running_variance, - epsilon, - use_global_stats, - &x_grad_grad, - &scale_grad_grad, - &bias_grad_grad, - x_grad, - scale_grad, - y_grad_grad); + paddle::operators::NormDoubleGradFunctor( + ctx, + data_layout, + &x, + &scale, + &y_grad, + &saved_mean, + &saved_variance, + running_mean, + running_variance, + epsilon, + use_global_stats, + x_grad_grad.get_ptr(), + scale_grad_grad.get_ptr(), + bias_grad_grad.get_ptr(), + x_grad, + scale_grad, + y_grad_grad); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/full_kernel.cu b/paddle/phi/kernels/gpu/full_kernel.cu index 120e908ae8cf7d628ddedfb6e28df43deceed7a1..4f030bc775b89412239681daffcb3ac76efc4842 100644 --- a/paddle/phi/kernels/gpu/full_kernel.cu +++ b/paddle/phi/kernels/gpu/full_kernel.cu @@ -138,6 +138,8 @@ PD_REGISTER_KERNEL(full_like, int64_t, bool, phi::dtype::bfloat16, - phi::dtype::float16) { + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) { kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); } diff --git a/paddle/phi/kernels/impl/activation_grad_impl.h b/paddle/phi/kernels/impl/activation_grad_impl.h index dd7dadc1e1cf9ec58ae8fa9a66c372531b6b473a..c8f0cddbf75fa0623bcdca7e6df1ee9f9b187f94 100644 --- a/paddle/phi/kernels/impl/activation_grad_impl.h +++ b/paddle/phi/kernels/impl/activation_grad_impl.h @@ -177,8 +177,8 @@ void TanhTripleGradKernel(const Context& dev_ctx, const DenseTensor& out, const DenseTensor& dout, const DenseTensor& ddx, - const DenseTensor& d_dout_new, - const DenseTensor& d_ddout, + const paddle::optional& d_dout_new, + const paddle::optional& d_ddout, DenseTensor* d_out_new, DenseTensor* d_dout, DenseTensor* d_ddx) { @@ -199,8 +199,8 @@ void TanhTripleGradKernel(const Context& dev_ctx, &out, &ddx, &dout, - &d_ddout, - &d_dout_new, // input + d_ddout.get_ptr(), + d_dout_new.get_ptr(), // input d_dout, d_out_new, d_ddx); // output @@ -597,49 +597,45 @@ void SquareDoubleGradKernel(const Context& dev_ctx, template void SinDoubleGradKernel(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& dout, + const paddle::optional& dout, const DenseTensor& ddx, DenseTensor* dx, DenseTensor* ddout) { if (dx) { - dx->Resize(x.dims()); dev_ctx.template Alloc(dx); } if (ddout) { dev_ctx.template Alloc(ddout); } phi::funcs::SinDoubleGradFunctor functor; - functor(dev_ctx, &x, &dout, &ddx, dx, ddout); + functor(dev_ctx, &x, dout.get_ptr(), &ddx, dx, ddout); } template void SinTripleGradKernel(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& dout, - const DenseTensor& ddx, + const paddle::optional& dout, + const paddle::optional& ddx, const DenseTensor& d_dx_new, - const DenseTensor& d_ddout, + const paddle::optional& d_ddout, DenseTensor* d_x_new, DenseTensor* d_dout, DenseTensor* d_ddx) { if (d_dout) { - d_dout->Resize(x.dims()); dev_ctx.template Alloc(d_dout); } if (d_x_new) { - d_dout->Resize(x.dims()); dev_ctx.template Alloc(d_x_new); } if (d_ddx) { - d_dout->Resize(ddx.dims()); dev_ctx.template Alloc(d_ddx); } funcs::SinTripleGradFunctor functor; functor(dev_ctx, &x, - &ddx, - &dout, - &d_ddout, + ddx.get_ptr(), + dout.get_ptr(), + d_ddout.get_ptr(), &d_dx_new, // input d_dout, d_x_new, @@ -649,49 +645,45 @@ void SinTripleGradKernel(const Context& dev_ctx, template void CosDoubleGradKernel(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& dout, + const paddle::optional& dout, const DenseTensor& ddx, DenseTensor* dx, DenseTensor* ddout) { if (dx) { - dx->Resize(x.dims()); dev_ctx.template Alloc(dx); } if (ddout) { dev_ctx.template Alloc(ddout); } phi::funcs::CosDoubleGradFunctor functor; - functor(dev_ctx, &x, &dout, &ddx, dx, ddout); + functor(dev_ctx, &x, dout.get_ptr(), &ddx, dx, ddout); } template void CosTripleGradKernel(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& dout, - const DenseTensor& ddx, + const paddle::optional& dout, + const paddle::optional& ddx, const DenseTensor& d_dx_new, - const DenseTensor& d_ddout, + const paddle::optional& d_ddout, DenseTensor* d_x_new, DenseTensor* d_dout, DenseTensor* d_ddx) { if (d_dout) { - d_dout->Resize(x.dims()); dev_ctx.template Alloc(d_dout); } if (d_x_new) { - d_dout->Resize(x.dims()); dev_ctx.template Alloc(d_x_new); } if (d_ddx) { - d_dout->Resize(ddx.dims()); dev_ctx.template Alloc(d_ddx); } funcs::CosTripleGradFunctor functor; functor(dev_ctx, &x, - &ddx, - &dout, - &d_ddout, + ddx.get_ptr(), + dout.get_ptr(), + d_ddout.get_ptr(), &d_dx_new, // input d_dout, d_x_new, diff --git a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h index 7759de509af56bdce4a4aa84ac946c6a0f094055..28387975e6e9982dcb3d025d8c3aff2601526fd0 100644 --- a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/phi/common/float16.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" @@ -472,6 +473,7 @@ void MultiplyDoubleGradKernel(const Context& dev_ctx, funcs::MultiplyFunctor, funcs::InverseMultiplyFunctor>( dev_ctx, y, ddx_safe, ddout, axis); + funcs::DefaultElementwiseOperator, @@ -483,42 +485,70 @@ void MultiplyDoubleGradKernel(const Context& dev_ctx, ddout_t.device(place) = ddout_t + ddout_tmp_t; } else { // use dx to save memory, other than alloc tmp tensor - DenseTensor* ddout_tmp = dx; - funcs::DefaultElementwiseOperator, - funcs::InverseMultiplyFunctor>( - dev_ctx, x, ddy_safe, ddout_tmp, axis); - // NOTE: in the following ElemwiseGradCompute, for the - // first output tensor is nullptr, the branch to calculate first - // output tensor will not be activated, DivGradDx function will not - // be called and can be ignored, the first branch has little effect - // on running speed. - phi::funcs::ElemwiseGradCompute, MulGradDY>( - dev_ctx, - ddx_safe, - ddy_safe, - dout, - dout, - axis, - nullptr, - dy, - MulGradDX(), - MulGradDY()); - funcs::DefaultElementwiseOperator, - funcs::InverseMultiplyFunctor>( - dev_ctx, ddx_safe, y, ddout, axis); + if (dx) { + DenseTensor* ddout_tmp = dx; + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, x, ddy_safe, ddout_tmp, axis); + + // NOTE: in the following ElemwiseGradCompute, for the + // first output tensor is nullptr, the branch to calculate first + // output tensor will not be activated, DivGradDx function will not + // be called and can be ignored, the first branch has little effect + // on running speed. + phi::funcs::ElemwiseGradCompute, MulGradDY>( + dev_ctx, + ddx_safe, + ddy_safe, + dout, + dout, + axis, + nullptr, + dy, + MulGradDX(), + MulGradDY()); + + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, ddx_safe, y, ddout, axis); - auto ddout_t = phi::EigenVector::Flatten(*ddout); - auto ddout_tmp_t = phi::EigenVector::Flatten(*ddout_tmp); - ddout_t.device(place) = ddout_t + ddout_tmp_t; - funcs::DefaultElementwiseOperator, - funcs::InverseMultiplyFunctor>( - dev_ctx, dout, ddy_safe, dx, axis); + auto ddout_t = phi::EigenVector::Flatten(*ddout); + auto ddout_tmp_t = phi::EigenVector::Flatten(*ddout_tmp); + ddout_t.device(place) = ddout_t + ddout_tmp_t; + + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, dout, ddy_safe, dx, axis); + + } else { + DenseTensor tmp_a(ddout->dtype()); + tmp_a.Resize(ddout->dims()); + + dev_ctx.template Alloc(&tmp_a); + + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, x, ddy_safe, &tmp_a, axis); + + auto ddout_t1 = phi::EigenVector::Flatten(tmp_a); + + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, ddx_safe, y, ddout, axis); + + auto ddout_t2 = phi::EigenVector::Flatten(*ddout); + ddout_t2.device(place) = ddout_t2 + ddout_t1; + } } } else { if (dx && dy) { @@ -544,8 +574,8 @@ void MultiplyTripleGradKernel(const Context& dev_ctx, const DenseTensor& dout, const paddle::optional& ddx, const paddle::optional& ddy, - const DenseTensor& d_dx, - const DenseTensor& d_dy, + const paddle::optional& d_dx, + const paddle::optional& d_dy, const paddle::optional& d_ddout, int axis, DenseTensor* d_x, @@ -599,6 +629,13 @@ void MultiplyTripleGradKernel(const Context& dev_ctx, funcs::InverseMultiplyFunctor>( dev_ctx, ddx_safe, *(d_ddout.get_ptr()), d_y, axis); } + } else { + if (d_x) { + FullLikeKernel(dev_ctx, x, Scalar(0.0), x.dtype(), d_x); + } + if (d_y) { + FullLikeKernel(dev_ctx, y, Scalar(0.0), y.dtype(), d_y); + } } if (d_dout) { @@ -607,61 +644,135 @@ void MultiplyTripleGradKernel(const Context& dev_ctx, DenseTensor d_dout_tmp; d_dout_tmp.Resize(dout.dims()); dev_ctx.template Alloc(&d_dout_tmp); - funcs::DefaultElementwiseOperator, - funcs::InverseMultiplyFunctor>( - dev_ctx, d_dy, ddx_safe, d_dout, axis); - funcs::DefaultElementwiseOperator, - funcs::InverseMultiplyFunctor>( - dev_ctx, ddy_safe, d_dx, &d_dout_tmp, axis); - auto d_dout_t = phi::EigenVector::Flatten(*d_dout); - auto d_dout_tmp_t = phi::EigenVector::Flatten(d_dout_tmp); - d_dout_t.device(place) = d_dout_t + d_dout_tmp_t; + + if (d_dy && d_dx) { + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, d_dy.get(), ddx_safe, d_dout, axis); + + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, ddy_safe, d_dx.get(), &d_dout_tmp, axis); + + auto d_dout_t = phi::EigenVector::Flatten(*d_dout); + auto d_dout_tmp_t = phi::EigenVector::Flatten(d_dout_tmp); + d_dout_t.device(place) = d_dout_t + d_dout_tmp_t; + } else if (d_dy && !d_dx) { + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, d_dy.get(), ddx_safe, d_dout, axis); + auto d_dout_t = phi::EigenVector::Flatten(*d_dout); + d_dout_t.device(place) = d_dout_t; + } else if (!d_dy && d_dx) { + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, ddy_safe, d_dx.get(), d_dout, axis); + + auto d_dout_t = phi::EigenVector::Flatten(*d_dout); + d_dout_t.device(place) = d_dout_t; + } else { + FullLikeKernel( + dev_ctx, dout, Scalar(0.0), dout.dtype(), d_dout); + } } - if (d_ddx) { + if (d_ddx && ddx) { // get d_ddx // d_ddx = dout * d_dy + y * d_ddout DenseTensor d_ddx_tmp; d_ddx_tmp.Resize(ddx->dims()); dev_ctx.template Alloc(&d_ddx_tmp); - funcs::DefaultElementwiseOperator, - funcs::InverseMultiplyFunctor>( - dev_ctx, dout, d_dy, d_ddx, axis); - funcs::DefaultElementwiseOperator, - funcs::InverseMultiplyFunctor>( - dev_ctx, y, *(d_ddout.get_ptr()), &d_ddx_tmp, axis); - auto d_ddx_t = phi::EigenVector::Flatten(*d_ddx); - auto d_ddx_tmp_t = phi::EigenVector::Flatten(d_ddx_tmp); - d_ddx_t.device(place) = d_ddx_t + d_ddx_tmp_t; + if (d_dy && d_ddout) { + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, dout, d_dy.get(), d_ddx, axis); + + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, y, *(d_ddout.get_ptr()), &d_ddx_tmp, axis); + + auto d_ddx_t = phi::EigenVector::Flatten(*d_ddx); + auto d_ddx_tmp_t = phi::EigenVector::Flatten(d_ddx_tmp); + d_ddx_t.device(place) = d_ddx_t + d_ddx_tmp_t; + } else if (d_dy && !d_ddout) { + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, dout, d_dy.get(), d_ddx, axis); + + auto d_ddx_t = phi::EigenVector::Flatten(*d_ddx); + d_ddx_t.device(place) = d_ddx_t; + } else if (!d_dy && d_ddout) { + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, y, *(d_ddout.get_ptr()), d_ddx, axis); + + auto d_ddx_t = phi::EigenVector::Flatten(*d_ddx); + d_ddx_t.device(place) = d_ddx_t; + } else { + FullLikeKernel(dev_ctx, x, Scalar(0.0), x.dtype(), d_ddx); + } } - if (d_ddy) { + if (d_ddy && ddy) { // get d_ddy // d_ddy = dout * d_dx + x * d_ddout DenseTensor d_ddy_tmp; d_ddy_tmp.Resize(ddy->dims()); dev_ctx.template Alloc(&d_ddy_tmp); - funcs::DefaultElementwiseOperator, - funcs::InverseMultiplyFunctor>( - dev_ctx, dout, d_dx, d_ddy, axis); - funcs::DefaultElementwiseOperator, - funcs::InverseMultiplyFunctor>( - dev_ctx, x, *(d_ddout.get_ptr()), &d_ddy_tmp, axis); - auto d_ddy_t = phi::EigenVector::Flatten(*d_ddy); - auto d_ddy_tmp_t = phi::EigenVector::Flatten(d_ddy_tmp); - d_ddy_t.device(place) = d_ddy_t + d_ddy_tmp_t; + + if (d_dx && d_ddout) { + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, dout, d_dx.get(), d_ddy, axis); + + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, x, *(d_ddout.get_ptr()), &d_ddy_tmp, axis); + + auto d_ddy_t = phi::EigenVector::Flatten(*d_ddy); + auto d_ddy_tmp_t = phi::EigenVector::Flatten(d_ddy_tmp); + d_ddy_t.device(place) = d_ddy_t + d_ddy_tmp_t; + } else if (d_dx && !d_ddout) { + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, dout, d_dx.get(), d_ddy, axis); + + auto d_ddy_t = phi::EigenVector::Flatten(*d_ddy); + d_ddy_t.device(place) = d_ddy_t; + } else if (!d_dx && d_ddout) { + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, x, *(d_ddout.get_ptr()), d_ddy, axis); + + auto d_ddy_t = phi::EigenVector::Flatten(*d_ddy); + d_ddy_t.device(place) = d_ddy_t; + } else { + FullLikeKernel(dev_ctx, y, Scalar(0.0), y.dtype(), d_ddy); + } } } diff --git a/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h b/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h index 85a530b1b7559733a1393a9429969913da5eee23..0eeae849bcfedff487d4db41cc80f24a7488ea1b 100644 --- a/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h +++ b/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h @@ -15,6 +15,7 @@ #pragma once #include + #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/kernel_registry.h" diff --git a/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h b/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h index f499e59c307291f91de5b25b72fd4a12b20d68a6..1bc29a34d46e1be8d10c6fcb9f4faca3cd0ad006 100644 --- a/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/kernels/complex_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/funcs/reduce_function.h" #include "paddle/phi/kernels/funcs/reduce_functor.h" #include "paddle/phi/kernels/impl/dot_grad_kernel_impl.h" @@ -262,6 +263,7 @@ void MatmulGradKernel(const Context& dev_ctx, DenseTensor x_help = x; DenseTensor y_help = y; DenseTensor out_grad_help = out_grad; + ReshapeXYOutIntoMatrixSequence( &x_help, &y_help, &out_grad_help, transpose_x, transpose_y); @@ -471,13 +473,27 @@ void MatmulDoubleGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, const DenseTensor& dout, - const paddle::optional& ddx, - const paddle::optional& ddy, + const paddle::optional& ddx_opt, + const paddle::optional& ddy_opt, bool transpose_x, bool transpose_y, DenseTensor* dx, DenseTensor* dy, DenseTensor* ddout) { + paddle::optional ddx; + paddle::optional ddy; + if (!ddx_opt && (dy || ddout)) { + DenseTensor ddx_tmp = phi::FullLike(dev_ctx, x, Scalar(0.0)); + ddx = paddle::make_optional(ddx_tmp); + } else { + ddx = ddx_opt; + } + if (!ddy_opt && (dx || ddout)) { + DenseTensor ddy_tmp = phi::FullLike(dev_ctx, y, Scalar(0.0)); + ddy = paddle::make_optional(ddy_tmp); + } else { + ddy = ddy_opt; + } // Get dims from the input x, y, output_grad std::vector x_dims = vectorize(x.dims()); std::vector y_dims = vectorize(y.dims()); @@ -688,7 +704,7 @@ void MatmulDoubleGradKernel(const Context& dev_ctx, if (transpose_x) { if (transpose_y) { - if (dx) { + if (dx && ddy) { MatMulFunction(dev_ctx, ddy.get(), dout_conj, @@ -698,7 +714,7 @@ void MatmulDoubleGradKernel(const Context& dev_ctx, true, true); } - if (dy) { + if (dy && ddx) { MatMulFunction(dev_ctx, dout_conj, ddx.get(), @@ -709,7 +725,7 @@ void MatmulDoubleGradKernel(const Context& dev_ctx, true); } } else { - if (dx) + if (dx && ddy) { MatMulFunction(dev_ctx, ddy.get(), dout_conj, @@ -718,7 +734,8 @@ void MatmulDoubleGradKernel(const Context& dev_ctx, &dx_help, false, true); - if (dy) + } + if (dy && ddx) { MatMulFunction(dev_ctx, ddx.get(), dout_conj, @@ -727,10 +744,11 @@ void MatmulDoubleGradKernel(const Context& dev_ctx, &dy_help, false, false); + } } } else { if (transpose_y) { - if (dx) { + if (dx && ddy) { MatMulFunction(dev_ctx, dout_conj, ddy.get(), @@ -740,7 +758,7 @@ void MatmulDoubleGradKernel(const Context& dev_ctx, false, false); } - if (dy) { + if (dy && ddx) { MatMulFunction(dev_ctx, dout_conj, ddx.get(), @@ -751,7 +769,7 @@ void MatmulDoubleGradKernel(const Context& dev_ctx, false); } } else { - if (dx) { + if (dx && ddy) { MatMulFunction(dev_ctx, dout_conj, ddy.get(), @@ -761,7 +779,7 @@ void MatmulDoubleGradKernel(const Context& dev_ctx, false, true); } - if (dy) { + if (dy && ddx) { MatMulFunction(dev_ctx, ddx.get(), dout_conj, @@ -824,23 +842,28 @@ void MatmulDoubleGradKernel(const Context& dev_ctx, if (ddout) { // Calculate the gradient of OutputGrad(Out) - MatMulFunction(dev_ctx, - ddx.get(), - y_conj, - x_dims, - y_dims, - ddout, - transpose_x, - transpose_y); - MatMulFunction(dev_ctx, - x_conj, - ddy.get(), - x_dims, - y_dims, - ddout, - transpose_x, - transpose_y, - true); + if (ddx) { + MatMulFunction(dev_ctx, + ddx.get(), + y_conj, + x_dims, + y_dims, + ddout, + transpose_x, + transpose_y); + } + + if (ddy) { + MatMulFunction(dev_ctx, + x_conj, + ddy.get(), + x_dims, + y_dims, + ddout, + transpose_x, + transpose_y, + true); + } } } } @@ -850,11 +873,11 @@ void MatmulTripleGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, const DenseTensor& dout, - const DenseTensor& ddx, - const DenseTensor& ddy, - const paddle::optional& d_dx, - const paddle::optional& d_dy, - const paddle::optional& d_ddout, + const paddle::optional& ddx_opt, + const paddle::optional& ddy_opt, + const paddle::optional& d_dx_opt, + const paddle::optional& d_dy_opt, + const paddle::optional& d_ddout_opt, bool transpose_x, bool transpose_y, DenseTensor* out_d_x, @@ -862,6 +885,50 @@ void MatmulTripleGradKernel(const Context& dev_ctx, DenseTensor* out_d_dout, DenseTensor* out_d_ddx, DenseTensor* out_d_ddy) { + paddle::optional ddx; + paddle::optional ddy; + paddle::optional d_dx; + paddle::optional d_dy; + paddle::optional d_ddout; + + if (!ddx_opt && (out_d_y || out_d_dout)) { + DenseTensor ddx_tmp = + phi::FullLike(dev_ctx, x, static_cast(0.0)); + ddx = paddle::make_optional(ddx_tmp); + } else { + ddx = ddx_opt; + } + if (!ddy_opt && (out_d_x || out_d_dout)) { + DenseTensor ddy_tmp = + phi::FullLike(dev_ctx, y, static_cast(0.0)); + ddy = paddle::make_optional(ddy_tmp); + } else { + ddy = ddy_opt; + } + + if (!d_ddout_opt && (out_d_y || out_d_x || out_d_ddy || out_d_ddx)) { + DenseTensor d_ddout_tmp = + phi::FullLike(dev_ctx, dout, static_cast(0.0)); + d_ddout = paddle::make_optional(d_ddout_tmp); + } else { + d_ddout = d_ddout_opt; + } + + if (!d_dx_opt && (out_d_ddy || out_d_dout)) { + DenseTensor d_dx_tmp = + phi::FullLike(dev_ctx, x, static_cast(0.0)); + d_dx = paddle::make_optional(d_dx_tmp); + } else { + d_dx = d_dx_opt; + } + + if (!d_dy_opt && (out_d_ddx || out_d_dout)) { + DenseTensor d_dy_tmp = + phi::FullLike(dev_ctx, y, static_cast(0.0)); + d_dy = paddle::make_optional(d_dy_tmp); + } else { + d_dy = d_dy_opt; + } // Get dims from the input x, y, output_grad std::vector x_dims = vectorize(x.dims()); std::vector y_dims = vectorize(y.dims()); @@ -877,8 +944,8 @@ void MatmulTripleGradKernel(const Context& dev_ctx, DotTripleGradFunction()(dev_ctx, &x, &y, - &ddx, - &ddy, + ddx.get_ptr(), + ddy.get_ptr(), d_dx.get_ptr(), d_dy.get_ptr(), &dout, @@ -913,17 +980,23 @@ void MatmulTripleGradKernel(const Context& dev_ctx, DenseTensor x_help = x; DenseTensor y_help = y; DenseTensor dout_help = dout; - DenseTensor ddx_help = ddx; - DenseTensor ddy_help = ddy; + + DenseTensor ddx_help; + DenseTensor ddy_help; ReshapeXYOutIntoMatrixSequence( &x_help, &y_help, &dout_help, transpose_x, transpose_y); - - if (ddx_help.dims() != x_help.dims()) { - ddx_help.Resize(x_help.dims()); + if (ddx) { + ddx_help = ddx.get(); + if (ddx_help.dims() != x_help.dims()) { + ddx_help.Resize(x_help.dims()); + } } - if (ddy_help.dims() != y_help.dims()) { - ddy_help.Resize(y_help.dims()); + if (ddy) { + ddy_help = ddy.get(); + if (ddy_help.dims() != y_help.dims()) { + ddy_help.Resize(y_help.dims()); + } } DDim out_dx_dims; @@ -932,60 +1005,64 @@ void MatmulTripleGradKernel(const Context& dev_ctx, if (out_dx_dims != x_help.dims()) { out_d_x->Resize(x_help.dims()); } + if (ddy) { + ddy_conj = Conj(dev_ctx, ddy_help); + } } - DDim out_dy_dims; if (out_d_y) { out_dy_dims = out_d_y->dims(); if (out_dy_dims != y_help.dims()) { out_d_y->Resize(y_help.dims()); } + if (ddx) { + ddx_conj = Conj(dev_ctx, ddx_help); + } } - DDim out_d_dout_dims; if (out_d_dout) { out_d_dout_dims = out_d_dout->dims(); if (out_d_dout_dims != dout_help.dims()) { out_d_dout->Resize(dout_help.dims()); } - - ddx_conj = Conj(dev_ctx, ddx_help); - ddy_conj = Conj(dev_ctx, ddy_help); + if (ddx && !ddx_conj.IsInitialized()) { + ddx_conj = Conj(dev_ctx, ddx_help); + } + if (ddy && !ddy_conj.IsInitialized()) { + ddy_conj = Conj(dev_ctx, ddy_help); + } } - DDim out_d_ddx_dims; if (out_d_ddx) { out_d_ddx_dims = out_d_ddx->dims(); if (out_d_ddx_dims != x_help.dims()) { out_d_ddx->Resize(x_help.dims()); } + dout_conj = Conj(dev_ctx, dout_help); + y_conj = Conj(dev_ctx, y_help); } - DDim out_d_ddy_dims; if (out_d_ddy) { out_d_ddy_dims = out_d_ddy->dims(); if (out_d_ddy_dims != y_help.dims()) { out_d_ddy->Resize(y_help.dims()); } - } - - if (out_d_ddx || out_d_ddy) { + if (dout_conj.IsInitialized()) { + dout_conj = Conj(dev_ctx, dout_help); + } x_conj = Conj(dev_ctx, x_help); - y_conj = Conj(dev_ctx, y_help); - dout_conj = Conj(dev_ctx, dout_help); } bool d_dout_flag = false; bool d_ddx_flag = false; bool d_ddy_flag = false; - if (d_ddout) { auto d_ddout_mat = d_ddout.get(); if (d_ddout_mat.dims() != dout_help.dims()) { d_ddout_mat.Resize(dout_help.dims()); } - if (out_d_y) { + if (out_d_y && ddx) { if (transpose_x && transpose_y) { // out_d_y = d_ddout' * ddx' CalcInputGrad(dev_ctx, @@ -1032,7 +1109,7 @@ void MatmulTripleGradKernel(const Context& dev_ctx, false); } } - if (out_d_x) { + if (out_d_x && ddy) { if (transpose_x && transpose_y) { // out_d_x = ddy' * d_ddout' CalcInputGrad(dev_ctx, @@ -1201,7 +1278,7 @@ void MatmulTripleGradKernel(const Context& dev_ctx, } // compute d_dout1 - if (out_d_dout) { + if (out_d_dout && ddx) { CalcInputGrad(dev_ctx, ddx_conj, transpose_x, @@ -1271,7 +1348,7 @@ void MatmulTripleGradKernel(const Context& dev_ctx, } // compute d_dout2 - if (out_d_dout) { + if (out_d_dout && ddy) { CalcInputGrad(dev_ctx, d_dx_mat, transpose_x, @@ -1376,8 +1453,12 @@ void MatmulTripleGradKernel(const Context& dev_ctx, DenseTensor out_d_ddy_help; if (out_d_dout) { - ddx_conj = Conj(dev_ctx, ddx); - ddy_conj = Conj(dev_ctx, ddy); + if (ddx) { + ddx_conj = Conj(dev_ctx, ddx.get()); + } + if (ddy) { + ddy_conj = Conj(dev_ctx, ddy.get()); + } } if (out_d_ddx || out_d_ddy) { x_conj = Conj(dev_ctx, x); @@ -1388,7 +1469,7 @@ void MatmulTripleGradKernel(const Context& dev_ctx, if (transpose_x) { if (transpose_y) { // dX = ddY' d_ddout’, dY = d_ddout’ ddX' - if (out_d_x) + if (out_d_x && ddy && d_ddout) MatMulFunction(dev_ctx, ddy_conj, d_ddout.get(), @@ -1397,7 +1478,7 @@ void MatmulTripleGradKernel(const Context& dev_ctx, &out_dx_help, true, true); - if (out_d_y) + if (out_d_y && ddx && d_ddout) MatMulFunction(dev_ctx, d_ddout.get(), ddx_conj, @@ -1408,7 +1489,7 @@ void MatmulTripleGradKernel(const Context& dev_ctx, true); } else { // dX = ddY d_ddout', dY = ddX d_ddout - if (out_d_x) + if (out_d_x && ddy && d_ddout) MatMulFunction(dev_ctx, ddy_conj, d_ddout.get(), @@ -1417,7 +1498,7 @@ void MatmulTripleGradKernel(const Context& dev_ctx, &out_dx_help, false, true); - if (out_d_y) + if (out_d_y && ddx && d_ddout) MatMulFunction(dev_ctx, ddx_conj, d_ddout.get(), @@ -1427,10 +1508,11 @@ void MatmulTripleGradKernel(const Context& dev_ctx, false, false); } + } else { if (transpose_y) { // dX = d_ddout ddY, dY = d_ddout’ ddX - if (out_d_x) + if (out_d_x && ddy && d_ddout) MatMulFunction(dev_ctx, d_ddout.get(), ddy_conj, @@ -1439,7 +1521,7 @@ void MatmulTripleGradKernel(const Context& dev_ctx, &out_dx_help, false, false); - if (out_d_y) + if (out_d_y && ddx && d_ddout) MatMulFunction(dev_ctx, d_ddout.get(), ddx_conj, @@ -1450,7 +1532,7 @@ void MatmulTripleGradKernel(const Context& dev_ctx, false); } else { // dX = d_ddout ddY', dY = ddX' d_ddout - if (out_d_x) + if (out_d_x && ddy && d_ddout) MatMulFunction(dev_ctx, d_ddout.get(), ddy_conj, @@ -1459,7 +1541,7 @@ void MatmulTripleGradKernel(const Context& dev_ctx, &out_dx_help, false, true); - if (out_d_y) + if (out_d_y && ddx && d_ddout) MatMulFunction(dev_ctx, ddx_conj, d_ddout.get(), @@ -1501,6 +1583,7 @@ void MatmulTripleGradKernel(const Context& dev_ctx, dy_reduce_dims.push_back(idx); } } + // Reduce sum to get grad by ReduceSum if (out_d_x) { if (dx_reduce_dims.empty()) { @@ -1524,107 +1607,135 @@ void MatmulTripleGradKernel(const Context& dev_ctx, // compute d_dout if (out_d_dout) { - MatMulFunction(dev_ctx, - d_dx.get(), - ddy_conj, - x_dims, - y_dims, - out_d_dout, - transpose_x, - transpose_y); - MatMulFunction(dev_ctx, - ddx_conj, - d_dy.get(), - x_dims, - y_dims, - out_d_dout, - transpose_x, - transpose_y, - true); - } - // compute d_ddx - if (out_d_ddx) { - if (transpose_x && transpose_y) { - // out_d_ddx1 = y' * d_ddout' + if (d_dx && ddy) { MatMulFunction(dev_ctx, - y_conj, - d_ddout.get(), + d_dx.get(), + ddy_conj, + x_dims, y_dims, - dout_dims, - &out_d_ddx_help, - true, - true); - // out_d_ddx2 = D_DY' * DOut' + out_d_dout, + transpose_x, + transpose_y); + } + if (d_dy && ddx) { MatMulFunction(dev_ctx, + ddx_conj, d_dy.get(), - dout_conj, + x_dims, y_dims, - dout_dims, - &out_d_ddx_help, - true, - true, + out_d_dout, + transpose_x, + transpose_y, true); + } + } + + // compute d_ddx + if (out_d_ddx) { + if (transpose_x && transpose_y) { + // out_d_ddx1 = y' * d_ddout' + if (d_ddout) { + MatMulFunction(dev_ctx, + y_conj, + d_ddout.get(), + y_dims, + dout_dims, + &out_d_ddx_help, + true, + true); + } + + // out_d_ddx2 = D_DY' * DOut' + if (d_dy) { + MatMulFunction(dev_ctx, + d_dy.get(), + dout_conj, + y_dims, + dout_dims, + &out_d_ddx_help, + true, + true, + true); + } + } else if (transpose_x) { // out_d_ddx1 = y * d_ddout' - MatMulFunction(dev_ctx, - y_conj, - d_ddout.get(), - y_dims, - dout_dims, - &out_d_ddx_help, - false, - true); + if (d_ddout) { + MatMulFunction(dev_ctx, + y_conj, + d_ddout.get(), + y_dims, + dout_dims, + &out_d_ddx_help, + false, + true); + } + // out_d_ddx2 = D_DY * Dout' - MatMulFunction(dev_ctx, - d_dy.get(), - dout_conj, - y_dims, - dout_dims, - &out_d_ddx_help, - false, - true, - true); + if (d_dy) { + MatMulFunction(dev_ctx, + d_dy.get(), + dout_conj, + y_dims, + dout_dims, + &out_d_ddx_help, + false, + true, + true); + } + } else if (transpose_y) { // out_d_ddx1 = d_ddout * y - MatMulFunction(dev_ctx, - d_ddout.get(), - y_conj, - dout_dims, - y_dims, - &out_d_ddx_help, - false, - false); + if (d_ddout) { + MatMulFunction(dev_ctx, + d_ddout.get(), + y_conj, + dout_dims, + y_dims, + &out_d_ddx_help, + false, + false); + } + // out_d_ddx2 = Dout * D_DY - MatMulFunction(dev_ctx, - dout_conj, - d_dy.get(), - dout_dims, - y_dims, - &out_d_ddx_help, - false, - false, - true); + if (d_dy) { + MatMulFunction(dev_ctx, + dout_conj, + d_dy.get(), + dout_dims, + y_dims, + &out_d_ddx_help, + false, + false, + true); + } } else { // out_d_ddx1 = d_ddout * y' - MatMulFunction(dev_ctx, - d_ddout.get(), - y_conj, - dout_dims, - y_dims, - &out_d_ddx_help, - false, - true); + if (d_ddout) { + MatMulFunction(dev_ctx, + d_ddout.get(), + y_conj, + dout_dims, + y_dims, + &out_d_ddx_help, + false, + true); + } + // out_d_ddx2 = Dout * D_DY' - MatMulFunction(dev_ctx, - dout_conj, - d_dy.get(), - dout_dims, - y_dims, - &out_d_ddx_help, - false, - true, - true); + if (d_dy) { + MatMulFunction(dev_ctx, + dout_conj, + d_dy.get(), + dout_dims, + y_dims, + &out_d_ddx_help, + false, + true, + true); + } } + if (dx_reduce_dims.empty()) { *out_d_ddx = std::move(out_d_ddx_help); } else { @@ -1638,84 +1749,107 @@ void MatmulTripleGradKernel(const Context& dev_ctx, if (out_d_ddy) { if (transpose_x && transpose_y) { // out_d_ddy1 = d_ddout' * x' - MatMulFunction(dev_ctx, - d_ddout.get(), - x_conj, - dout_dims, - x_dims, - &out_d_ddy_help, - true, - true); + if (d_ddout) { + MatMulFunction(dev_ctx, + d_ddout.get(), + x_conj, + dout_dims, + x_dims, + &out_d_ddy_help, + true, + true); + } + // out_d_ddy2 = dout' * d_dx' - MatMulFunction(dev_ctx, - dout_conj, - d_dx.get(), - dout_dims, - x_dims, - &out_d_ddy_help, - true, - true, - true); + if (d_dx) { + MatMulFunction(dev_ctx, + dout_conj, + d_dx.get(), + dout_dims, + x_dims, + &out_d_ddy_help, + true, + true, + true); + } + } else if (transpose_x) { // out_d_ddy1 = x * d_ddout - MatMulFunction(dev_ctx, - x_conj, - d_ddout.get(), - x_dims, - dout_dims, - &out_d_ddy_help, - false, - false); + if (d_ddout) { + MatMulFunction(dev_ctx, + x_conj, + d_ddout.get(), + x_dims, + dout_dims, + &out_d_ddy_help, + false, + false); + } + // out_d_ddy2 = d_dx * dout - MatMulFunction(dev_ctx, - d_dx.get(), - dout_conj, - x_dims, - dout_dims, - &out_d_ddy_help, - false, - false, - true); + if (d_dx) { + MatMulFunction(dev_ctx, + d_dx.get(), + dout_conj, + x_dims, + dout_dims, + &out_d_ddy_help, + false, + false, + true); + } + } else if (transpose_y) { // out_d_ddy1 = d_ddout' * x - MatMulFunction(dev_ctx, - d_ddout.get(), - x_conj, - dout_dims, - x_dims, - &out_d_ddy_help, - true, - false); + if (d_ddout) { + MatMulFunction(dev_ctx, + d_ddout.get(), + x_conj, + dout_dims, + x_dims, + &out_d_ddy_help, + true, + false); + } + // out_d_ddy2 = dout' * d_dx - MatMulFunction(dev_ctx, - dout_conj, - d_dx.get(), - dout_dims, - x_dims, - &out_d_ddy_help, - true, - false, - true); + if (d_dx) { + MatMulFunction(dev_ctx, + dout_conj, + d_dx.get(), + dout_dims, + x_dims, + &out_d_ddy_help, + true, + false, + true); + } + } else { // out_d_ddy1 = x' * d_ddout - MatMulFunction(dev_ctx, - x_conj, - d_ddout.get(), - x_dims, - dout_dims, - &out_d_ddy_help, - true, - false); + if (d_ddout) { + MatMulFunction(dev_ctx, + x_conj, + d_ddout.get(), + x_dims, + dout_dims, + &out_d_ddy_help, + true, + false); + } + // out_d_ddy2 = d_dx' * dout - MatMulFunction(dev_ctx, - d_dx.get(), - dout_conj, - x_dims, - dout_dims, - &out_d_ddy_help, - true, - false, - true); + if (d_dx) { + MatMulFunction(dev_ctx, + d_dx.get(), + dout_conj, + x_dims, + dout_dims, + &out_d_ddy_help, + true, + false, + true); + } } if (dy_reduce_dims.empty()) { diff --git a/paddle/phi/kernels/logcumsumexp_grad_kernel.h b/paddle/phi/kernels/logcumsumexp_grad_kernel.h index e78a79550657eb67c07c8bbd5a34b2e5e4e9f3dd..a16dc5318cb1ffe5070a0d4507d1eacbe4abe7f8 100644 --- a/paddle/phi/kernels/logcumsumexp_grad_kernel.h +++ b/paddle/phi/kernels/logcumsumexp_grad_kernel.h @@ -28,4 +28,4 @@ void LogcumsumexpGradKernel(const Context& dev_ctx, bool exclusive, bool reverse, DenseTensor* d_x); -} +} // namespace phi diff --git a/paddle/phi/kernels/matmul_grad_kernel.h b/paddle/phi/kernels/matmul_grad_kernel.h index 47c6acdcb392309e3ca7849298f89b0e5cf9ef1e..572b58eb0ddc64f2119930629a77fe9feb51422c 100644 --- a/paddle/phi/kernels/matmul_grad_kernel.h +++ b/paddle/phi/kernels/matmul_grad_kernel.h @@ -47,8 +47,8 @@ void MatmulTripleGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, const DenseTensor& dout, - const DenseTensor& ddx, - const DenseTensor& ddy, + const paddle::optional& ddx, + const paddle::optional& ddy, const paddle::optional& d_dx, const paddle::optional& d_dy, const paddle::optional& d_ddout, diff --git a/paddle/phi/kernels/xpu/full_kernel.cc b/paddle/phi/kernels/xpu/full_kernel.cc index c5fca8881e221e37e04af4c3d116173c790b7e76..ab3e0344478a467c209b3836dd1008ac9b2ce24d 100644 --- a/paddle/phi/kernels/xpu/full_kernel.cc +++ b/paddle/phi/kernels/xpu/full_kernel.cc @@ -140,6 +140,8 @@ PD_REGISTER_KERNEL(full_like, float, int, int64_t, - phi::dtype::float16) { + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) { kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); }