From 69a4a39f92a4b832e885d219c332cf77b6320b49 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Sat, 9 Jul 2022 12:56:55 +0800 Subject: [PATCH] merge develop (#43995) --- paddle/fluid/eager/api/manual/CMakeLists.txt | 1 + .../manual/eager_manual/dygraph_forward_api.h | 3 + .../eager_manual/forwards/CMakeLists.txt | 9 +- .../eager_manual/forwards/add_n_fwd_func.cc | 109 ++++++++++++++++++ .../manual/eager_manual/nodes/CMakeLists.txt | 7 +- .../manual/eager_manual/nodes/add_n_node.cc | 78 +++++++++++++ .../api/manual/eager_manual/nodes/nodes.h | 47 ++++++++ .../final_state_generator/eager_gen.py | 6 +- paddle/phi/api/lib/api_custom_impl.cc | 43 ------- paddle/phi/api/lib/api_custom_impl.h | 4 - paddle/phi/api/yaml/legacy_backward.yaml | 7 -- .../fluid/tests/unittests/test_sum_op.py | 23 ++++ 12 files changed, 280 insertions(+), 57 deletions(-) create mode 100644 paddle/fluid/eager/api/manual/eager_manual/forwards/add_n_fwd_func.cc create mode 100644 paddle/fluid/eager/api/manual/eager_manual/nodes/add_n_node.cc diff --git a/paddle/fluid/eager/api/manual/CMakeLists.txt b/paddle/fluid/eager/api/manual/CMakeLists.txt index e6db90ccc5b..8c4ce6d2bdb 100644 --- a/paddle/fluid/eager/api/manual/CMakeLists.txt +++ b/paddle/fluid/eager/api/manual/CMakeLists.txt @@ -6,6 +6,7 @@ if(NOT ((NOT WITH_PYTHON) AND ON_INFER)) set(fluid_manual_nodes ${fluid_manual_nodes} PARENT_SCOPE) + add_subdirectory(eager_manual) set(eager_manual_functions ${eager_manual_functions} diff --git a/paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h b/paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h index 0f068310681..f9d10600a9a 100644 --- a/paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h +++ b/paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h @@ -16,6 +16,9 @@ #include "paddle/phi/api/include/tensor.h" +paddle::experimental::Tensor add_n_final_state_dygraph_function( + const std::vector& x); + paddle::experimental::Tensor conv2d_final_state_dygraph_function( const paddle::experimental::Tensor& input, const paddle::experimental::Tensor& filter, diff --git a/paddle/fluid/eager/api/manual/eager_manual/forwards/CMakeLists.txt b/paddle/fluid/eager/api/manual/eager_manual/forwards/CMakeLists.txt index 0ed2f26c0b2..d71f1153e2f 100644 --- a/paddle/fluid/eager/api/manual/eager_manual/forwards/CMakeLists.txt +++ b/paddle/fluid/eager/api/manual/eager_manual/forwards/CMakeLists.txt @@ -1,3 +1,10 @@ +cc_library( + add_n_fwd_func + SRCS add_n_fwd_func.cc + DEPS ${eager_deps} ${fluid_deps} ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS}) + +add_dependencies(add_n_fwd_func eager_codegen) + cc_library( conv2d_fwd_function SRCS conv2d_fwd_function.cc @@ -6,5 +13,5 @@ cc_library( add_dependencies(conv2d_fwd_function eager_codegen) set(eager_manual_functions - conv2d_fwd_function + conv2d_fwd_function add_n_fwd_func PARENT_SCOPE) diff --git a/paddle/fluid/eager/api/manual/eager_manual/forwards/add_n_fwd_func.cc b/paddle/fluid/eager/api/manual/eager_manual/forwards/add_n_fwd_func.cc new file mode 100644 index 00000000000..226197b0f84 --- /dev/null +++ b/paddle/fluid/eager/api/manual/eager_manual/forwards/add_n_fwd_func.cc @@ -0,0 +1,109 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/eager/amp_utils.h" +#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h" +#include "paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h" +#include "paddle/fluid/eager/api/utils/global_utils.h" +#include "paddle/fluid/eager/eager_amp_auto_cast.h" +#include "paddle/fluid/eager/nan_inf_utils.h" +#include "paddle/fluid/platform/profiler/event_tracing.h" + +#pragma GCC diagnostic ignored "-Wunused-variable" +DECLARE_bool(check_nan_inf); + +paddle::experimental::Tensor add_n_final_state_dygraph_function( + const std::vector& x) { + // Dygraph Record Event + paddle::platform::RecordEvent dygraph_entrance_record_event( + "add_n dygraph", paddle::platform::TracerEventType::Operator, 1); + + // AMP Logic + if (egr::Controller::Instance().GetAMPLevel() != + paddle::imperative::AmpLevel::O0) { + VLOG(5) << "Check and Prepare For AMP"; + auto op_name = phi::TransToFluidOpName("add_n"); + paddle::small_vector, + egr::kSlotSmallVectorSize> + amp_tensors_vector = {x}; + + auto amp_dst_dtype = egr::GetAmpDestDtype(op_name, amp_tensors_vector); + + auto NEW_x = egr::EagerAmpAutoCasts("x", x, amp_dst_dtype, op_name); + + { + paddle::imperative::AutoCastGuard guard( + egr::Controller::Instance().GetCurrentTracer(), + paddle::imperative::AmpLevel::O0); + return add_n_final_state_dygraph_function(NEW_x); + } + } + + // Get Input AutoGradMeta + std::vector x_autograd_meta_vec = + egr::EagerUtils::nullable_autograd_meta(x); + std::vector* x_autograd_meta = &x_autograd_meta_vec; + // Forward API Call + VLOG(3) << "Final State Running: " + << "add_n_final_state_dygraph_function"; + auto api_result = paddle::experimental::add_n(x); + // Check NaN and Inf if needed + if (FLAGS_check_nan_inf) { + egr::CheckTensorHasNanOrInf("add_n", api_result); + } + + // Get Outputs + auto& out = api_result; + + // Get Output AutoGradMeta + egr::AutogradMeta* out_autograd_meta = egr::EagerUtils::autograd_meta(&out); + bool trace_backward = egr::Controller::Instance().HasGrad(); + bool require_any_grad = + egr::EagerUtils::ComputeRequireGrad(trace_backward, x_autograd_meta); + + // Check Inplace if needed + + // Node Creation + if (require_any_grad) { + paddle::platform::RecordEvent node_creation_record_event( + "add_n node_creation", + paddle::platform::TracerEventType::OperatorInner, + 1); + + egr::EagerUtils::PassStopGradient(false, out_autograd_meta); + + // Node Construction + auto grad_node = + std::shared_ptr(new AddNGradNodeFinal(1, 1)); + // SetAttributes if needed + + // Set TensorWrappers for Forward Inputs if needed + grad_node->SetTensorWrapperx(x); + // SetGradOutMeta & SetEdges + grad_node->SetGradOutMeta(x, 0); + // SetOutRank & SetHistory & SetGradInMeta & RetainGrad + if (out_autograd_meta) { + egr::EagerUtils::SetOutRankWithSlot(out_autograd_meta, 0); + } + if (out_autograd_meta) { + egr::EagerUtils::SetHistory(out_autograd_meta, grad_node); + } + grad_node->SetGradInMeta(out, 0); + egr::EagerUtils::CheckAndRetainGrad(out); + // Set TensorWrappers for Forward Outputs if needed + } + + // Returns + return out; +} diff --git a/paddle/fluid/eager/api/manual/eager_manual/nodes/CMakeLists.txt b/paddle/fluid/eager/api/manual/eager_manual/nodes/CMakeLists.txt index 21642fbd649..fa6a9a53aba 100644 --- a/paddle/fluid/eager/api/manual/eager_manual/nodes/CMakeLists.txt +++ b/paddle/fluid/eager/api/manual/eager_manual/nodes/CMakeLists.txt @@ -1,8 +1,13 @@ +cc_library( + add_n_node + SRCS add_n_node.cc + DEPS ${eager_deps} ${fluid_deps}) + cc_library( conv2d_nodes SRCS conv2d_nodes.cc DEPS ${eager_deps} ${fluid_deps}) set(eager_manual_nodes - conv2d_nodes + conv2d_nodes add_n_node PARENT_SCOPE) diff --git a/paddle/fluid/eager/api/manual/eager_manual/nodes/add_n_node.cc b/paddle/fluid/eager/api/manual/eager_manual/nodes/add_n_node.cc new file mode 100644 index 00000000000..e314c0c2b5b --- /dev/null +++ b/paddle/fluid/eager/api/manual/eager_manual/nodes/add_n_node.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "glog/logging.h" +#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" +#include "paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h" +#include "paddle/fluid/eager/api/utils/global_utils.h" +#include "paddle/fluid/eager/nan_inf_utils.h" +#include "paddle/fluid/eager/utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/imperative/tracer.h" +#include "paddle/phi/api/all.h" +#include "paddle/phi/api/lib/api_custom_impl.h" +DECLARE_bool(check_nan_inf); + +paddle::small_vector, + egr::kSlotSmallVectorSize> +AddNGradNodeFinal::operator()( + paddle::small_vector, + egr::kSlotSmallVectorSize>& grads, + bool create_graph, + bool is_new_grad) { + // Fill Zero For GradIn Tensors + + // Apply Gradient Hooks + auto hooked_grads = ApplyGradientHooks(grads); + + // Collect GradIn Tensors, Attrs and Recovered TensorWrappers + auto x = egr::EagerUtils::RecoverTensorWrapper(&this->x_); + auto& out_grad = hooked_grads[0][0]; + // Prepare Grad function call + + const auto& out_metas = OutputMeta(); + paddle::small_vector, + egr::kSlotSmallVectorSize> + returns(1); + for (int i = 0; i < 1; ++i) { + out_metas[i].size() == 0 ? returns[i].resize(1) + : returns[i].resize(out_metas[i].size()); + } + + std::vector api_output_0; + api_output_0.reserve(returns[0].size()); + for (size_t i = 0; i < returns[0].size(); ++i) { + if (out_metas[0].empty() || out_metas[0][i].IsStopGradient()) { + api_output_0.push_back(nullptr); + } else { + api_output_0.push_back(&returns[0][i]); + } + } + // Call grad_api function + VLOG(3) << "Final State Running: AddNGradNodeFinal"; + + // dygraph function + for (size_t i = 0; i < returns[0].size(); i++) { + returns[0][i] = ::scale_final_state_dygraph_function( + out_grad, phi::Scalar(1.0), 0.0, true); + } + + // Check NaN and Inf id needed + if (FLAGS_check_nan_inf) { + egr::CheckTensorHasNanOrInf("add_n_grad", returns); + } + + if (NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns); + return returns; +} diff --git a/paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h b/paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h index f202b64f0b7..14fe144c009 100644 --- a/paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h +++ b/paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/fluid/eager/grad_node_info.h" #include "paddle/fluid/eager/tensor_wrapper.h" +#include "paddle/fluid/imperative/tracer.h" class Conv2dGradNodeFinal : public egr::GradNodeBase { public: @@ -180,3 +181,49 @@ class Conv2dDoubleGradNodeFinal : public egr::GradNodeBase { int workspace_size_MB_; bool exhaustive_search_; }; + +class AddNGradNodeFinal : public egr::GradNodeBase { + public: + AddNGradNodeFinal() : egr::GradNodeBase() {} + AddNGradNodeFinal(size_t bwd_in_slot_num, size_t bwd_out_slot_num) + : egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {} + ~AddNGradNodeFinal() override = default; + + virtual paddle::small_vector, + egr::kSlotSmallVectorSize> + operator()( + paddle::small_vector, // NOLINT + egr::kSlotSmallVectorSize>& grads, // NOLINT + bool create_graph = false, + bool is_new_grad = false) override; + std::string name() override { return "AddNGradNodeFinal"; } + + void ClearTensorWrappers() override { + for (auto& tw : x_) { + tw.clear(); + } + + SetIsTensorWrappersCleared(true); + } + + std::shared_ptr Copy() const override { + auto copied_node = + std::shared_ptr(new AddNGradNodeFinal(*this)); + return copied_node; + } + + // SetTensorWrapperX, SetTensorWrapperY, ... + void SetTensorWrapperx(const std::vector& x) { + for (const auto& eager_tensor : x) { + x_.emplace_back(egr::TensorWrapper(eager_tensor, true)); + } + } + + // SetAttributes + + private: + // TensorWrappers + std::vector x_; + + // Attributes +}; diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index a6f5a36e389..a3beb268cfa 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -41,7 +41,7 @@ from codegen_utils import AssertMessage, GetIndent # and this will be fixed in the futrue. inplace_check_blacklist = set(["assign_out_"]) # # --- Black Ops list that's NO NEED to apply backward code generation -black_ops_list = ["conv2d", "conv2d_grad", "conv2d_grad_grad"] +black_ops_list = ["conv2d", "conv2d_grad", "conv2d_grad_grad", "add_n"] ########### @@ -283,6 +283,7 @@ NODE_H_FILE_TEMPLATE = \ #pragma once #include "paddle/fluid/eager/tensor_wrapper.h" #include "paddle/fluid/eager/grad_node_info.h" +#include "paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h" {} """ @@ -316,6 +317,7 @@ FORWARD_H_FILE_TEMPLATE = \ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/eager/to_static/run_program_op_func.h" #include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h" + {} {} """ @@ -1648,6 +1650,8 @@ class DygraphForwardAndNodesGenerator(GeneratorBase): namespace = self.namespace for forward_api_contents in forward_api_list: + if forward_api_contents['api'] in black_ops_list: continue + backward_api_contents = self.GetBackwardAPIContents( forward_api_contents) if backward_api_contents is None: continue diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc index b68418885ca..362c9606eba 100644 --- a/paddle/phi/api/lib/api_custom_impl.cc +++ b/paddle/phi/api/lib/api_custom_impl.cc @@ -871,49 +871,6 @@ std::tuple momentum_impl( ////////////////// Backward(grad) api impls ////////////////////// -// TODO(chenweihang): the original sum grad op can support higher-level -// differentiation, -// but if we use this impl, it will not support. We need to be able to reuse -// the autograd API here, which is not yet implemented -// TODO(chenweihang): we should support call generated api in custom api impl -void add_n_grad_impl(const std::vector& x, - const Tensor& out_grad, - std::vector x_grad) { - auto kernel_key_set = ParseKernelKeyByInputArgs(out_grad); - auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); - - Backend kernel_backend = kernel_key.backend(); - DataLayout kernel_layout = kernel_key.layout(); - DataType kernel_data_type = kernel_key.dtype(); - - auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( - "scale", {kernel_backend, kernel_layout, kernel_data_type}); - VLOG(6) << "add_n_grad API kernel key: [" << kernel_backend << ", " - << kernel_layout << ", " << kernel_data_type << "]"; - VLOG(6) << "add_n_grad API kernel: " << kernel; - - auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); - - auto dense_out_grad = PrepareData(out_grad, kernel.InputAt(0), {}); - - auto dense_x_grad = SetKernelOutput(&x_grad); - - using kernel_signature = void (*)(const platform::DeviceContext&, - const phi::DenseTensor&, - const phi::Scalar&, - float, - bool, - phi::DenseTensor*); - auto* kernel_fn = kernel.GetVariadicKernelFn(); - - for (auto* dense_x_grad_t : dense_x_grad) { - phi::MetaTensor meta_out(dense_x_grad_t); - phi::UnchangedInferMeta(MakeMetaTensor(*dense_out_grad), &meta_out); - (*kernel_fn)( - *dev_ctx, *dense_out_grad, phi::Scalar(1.0), 0.0, true, dense_x_grad_t); - } -} - std::tuple batch_norm_impl( const Tensor& x, const Tensor& scale, diff --git a/paddle/phi/api/lib/api_custom_impl.h b/paddle/phi/api/lib/api_custom_impl.h index 627ff2aabf1..ef695580a07 100644 --- a/paddle/phi/api/lib/api_custom_impl.h +++ b/paddle/phi/api/lib/api_custom_impl.h @@ -116,10 +116,6 @@ std::tuple momentum_impl( ////////////////// Backward(grad) api impls ////////////////////// -void add_n_grad_impl(const std::vector& x, - const Tensor& out_grad, - std::vector x_grad); - void conv2d_grad_impl(const Tensor& input, const Tensor& filter, const Tensor& out_grad, diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 4af32c7e4cf..f01598e6434 100644 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -71,13 +71,6 @@ backward : add_double_grad inplace : (out_grad -> x_grad) -- backward_api : add_n_grad - forward : add_n (Tensor[] x) -> Tensor(out) - args : (Tensor[] x, Tensor out_grad) - output : Tensor[](x_grad){x.size()} - invoke : add_n_grad_impl(x, out_grad, x_grad) - no_need_buffer : x - - backward_api : add_triple_grad forward : add_double_grad (Tensor y, Tensor grad_out, Tensor grad_grad_x, Tensor grad_grad_y, int axis = -1) -> Tensor(grad_grad_out) args : (Tensor grad_grad_x, Tensor grad_grad_y, Tensor grad_grad_out_grad, int axis = -1) diff --git a/python/paddle/fluid/tests/unittests/test_sum_op.py b/python/paddle/fluid/tests/unittests/test_sum_op.py index 9d1a4cf19eb..ad226878f7e 100644 --- a/python/paddle/fluid/tests/unittests/test_sum_op.py +++ b/python/paddle/fluid/tests/unittests/test_sum_op.py @@ -384,6 +384,29 @@ class API_Test_Add_n(unittest.TestCase): self.assertEqual( (input1.grad.numpy() == expected_grad_result).all(), True) + def test_add_n_and_add_and_grad(self): + with fluid.dygraph.guard(): + np_x = np.array([[1, 2, 3], [4, 5, 6]]) + np_y = [[7, 8, 9], [10, 11, 12]] + np_z = [[1, 1, 1], [1, 1, 1]] + x = paddle.to_tensor(np_x, dtype='float32', stop_gradient=False) + y = paddle.to_tensor(np_y, dtype='float32', stop_gradient=False) + z = paddle.to_tensor(np_z, dtype='float32') + + out1 = x + z + out2 = y + z + out = paddle.add_n([out1, out2]) + + dx, dy = paddle.grad([out], [x, y], create_graph=True) + + expected_out = np.array([[10., 12., 14.], [16., 18., 20.]]) + expected_dx = np.array([[1, 1, 1], [1, 1, 1]]) + expected_dy = np.array([[1, 1, 1], [1, 1, 1]]) + + self.assertTrue(np.allclose(out, expected_out)) + self.assertTrue(np.allclose(dx, expected_dx)) + self.assertTrue(np.allclose(dy, expected_dy)) + class TestRaiseSumError(unittest.TestCase): -- GitLab