From 227fa4083efa2a0aa902aa4edee18e7a367d0f2a Mon Sep 17 00:00:00 2001 From: Jiabin Yang <360788950@qq.com> Date: Mon, 14 Mar 2022 16:59:55 +0800 Subject: [PATCH] Support custom op and paddle.autograd.backward in eager (#40423) * eager, test=develop * fix bug, test=develop * eager, test=develop * merge legacy to fluid * eager, test=develop * eager, test=develop * Refactor TensorAdd func by template and remove gradient_accumulation in eager * Remove needless target name * eager, test=develop * eager, test=develop * Use overload instead of template * Remove legacy code * Remove legacy code * selectedrows, test=develop * Remove DataType test * eager, test=develop * eager, test=develop * support gan, test=develop * Using Tensor directly instead of using EagerTensor * support gradient_accumulation * make test_imperative_lod_tensor_to_selected_rows longer * make test_imperative_lod_tensor_to_selected_rows longer * refine code * ptb, test=develop * Rename all EagerTensor to Tensor * Rename some EagerTensor to Tensor * rename EagerTensor to EagerVariable * eager, test=develop * eager, test=develop * eager, test=develop * eager, test=develop * add more test * eager, test=develop * Support copiable selected rows and merge develop * save load, eager, test=develop * save load, eager, test=develop * refine, test=develop * remove useless _set_value method * refine, test=develop * refine, test=develop * revert static_runner, test=develop * EagerTensor to Tensor, test=develop * refine, test=develop * refine, test=develop * clear grad, test=develop * merge, develop * merge, develop * merge, test=develop * merge, test=develop * Support quant and part of slice * support legacy static save * extend slim tests time * remove imperative on inference * remove imperative on inference * merge develop * fix typo * fix typo * split slice related code into 2 part for imperative and eager * split slice from inference * split slice from inference * fix test_tensor_register_hook * support custom op in eager mode * fix inference deps error * split eager utils from custom operator * fix type match * fix typo Co-authored-by: Wang Huan Co-authored-by: Weilong Wu Co-authored-by: wanghuancoder --- paddle/fluid/eager/CMakeLists.txt | 5 +- paddle/fluid/eager/api/utils/global_utils.h | 22 +- paddle/fluid/eager/backward.cc | 30 +- .../eager/custom_operator/CMakeLists.txt | 1 + .../custom_operator/custom_operator_node.cc | 90 ++++++ .../custom_operator/custom_operator_node.h | 77 +++++ paddle/fluid/eager/grad_node_info.cc | 2 +- paddle/fluid/framework/CMakeLists.txt | 1 + paddle/fluid/framework/custom_operator.cc | 6 +- paddle/fluid/framework/custom_operator.h | 5 +- paddle/fluid/pybind/CMakeLists.txt | 2 +- paddle/fluid/pybind/eager_functions.cc | 277 +++++++++++++++++- paddle/fluid/pybind/eager_properties.cc | 2 +- paddle/fluid/pybind/eager_utils.cc | 21 +- paddle/fluid/pybind/eager_utils.h | 6 +- paddle/fluid/pybind/exception.cc | 4 +- paddle/fluid/pybind/pybind.cc | 62 +++- paddle/phi/api/ext/op_meta_info.h | 18 +- paddle/phi/api/lib/op_meta_info.cc | 17 +- paddle/phi/api/lib/tensor.cc | 4 +- paddle/phi/api/lib/tensor_method.cc | 1 + python/paddle/autograd/backward_mode.py | 27 +- .../fluid/dygraph/varbase_patch_methods.py | 2 +- .../fluid/tests/custom_op/custom_relu_op.cc | 2 + .../fluid/tests/custom_op/custom_relu_op.cu | 2 + .../tests/custom_op/test_custom_attrs_jit.py | 15 +- .../tests/custom_op/test_custom_concat.py | 15 +- .../fluid/tests/custom_op/test_custom_conj.py | 8 +- .../tests/custom_op/test_custom_linear.py | 8 +-
.../custom_op/test_custom_raw_op_kernel_op.py | 6 - .../tests/custom_op/test_custom_relu_model.py | 31 +- .../custom_op/test_custom_relu_op_jit.py | 19 +- .../custom_op/test_custom_relu_op_setup.py | 8 +- .../custom_op/test_custom_simple_slice.py | 8 +- .../tests/custom_op/test_dispatch_jit.py | 9 +- .../tests/custom_op/test_multi_out_jit.py | 9 +- .../tests/unittests/test_custom_grad_input.py | 42 ++- .../tests/unittests/test_egr_python_api.py | 4 +- python/paddle/utils/code_gen/api.yaml | 1 + .../utils/cpp_extension/extension_utils.py | 43 ++- 40 files changed, 803 insertions(+), 109 deletions(-) create mode 100644 paddle/fluid/eager/custom_operator/CMakeLists.txt create mode 100644 paddle/fluid/eager/custom_operator/custom_operator_node.cc create mode 100644 paddle/fluid/eager/custom_operator/custom_operator_node.h diff --git a/paddle/fluid/eager/CMakeLists.txt b/paddle/fluid/eager/CMakeLists.txt index f9d1b70539..691a381405 100644 --- a/paddle/fluid/eager/CMakeLists.txt +++ b/paddle/fluid/eager/CMakeLists.txt @@ -1,4 +1,5 @@ -set(eager_deps phi_api hook_utils tensor_utils utils global_utils backward phi_tensor tracer layer autograd_meta grad_node_info grad_tensor_holder accumulation_node) +set(eager_deps phi_api hook_utils tensor_utils utils global_utils backward phi_tensor tracer layer autograd_meta grad_node_info grad_tensor_holder accumulation_node custom_operator_node) + set(fluid_deps tracer layer proto_desc operator op_registry variable_helper memcpy) set(generated_deps final_dygraph_function final_dygraph_node dygraph_function dygraph_node) @@ -9,6 +10,8 @@ endif() add_subdirectory(api) add_subdirectory(accumulation) +add_subdirectory(custom_operator) + cc_library(grad_node_info SRCS grad_node_info.cc DEPS phi_api phi_tensor) cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulator) diff --git a/paddle/fluid/eager/api/utils/global_utils.h b/paddle/fluid/eager/api/utils/global_utils.h index 00578d9a35..a9a62fcd50 100644 --- a/paddle/fluid/eager/api/utils/global_utils.h +++ b/paddle/fluid/eager/api/utils/global_utils.h @@ -18,7 +18,7 @@ #include #include #include "paddle/fluid/imperative/tracer.h" - +#include "paddle/phi/api/ext/op_meta_info.h" namespace egr { class UniqueNameGenerator { @@ -70,6 +70,21 @@ class Controller { void SetInEagerMode(bool in_eager_mode) { in_eager_mode_ = in_eager_mode; } + const std::unordered_map>& + GetOpMetaInfoMap() { + return op_meta_info_map_; + } + + void MergeOpMetaInfoMap(const std::unordered_map< + std::string, std::vector>& map) { + op_meta_info_map_.insert(map.begin(), map.end()); + } + + std::unordered_map>>& + GetCustomEdgesSlotMap() { + return custom_edges_slot_map_; + } + private: Controller() = default; static Controller* controller_; @@ -77,6 +92,11 @@ class Controller { new paddle::imperative::Tracer()}; // TODO(jiabin): remove when we don't need imperative. 
bool in_eager_mode_{false}; + std::unordered_map> + op_meta_info_map_; + /* op_type : {{grad_outputs}, {grad_inputs}, {input}, {output}, {attrs}}*/ + std::unordered_map>> + custom_edges_slot_map_; DISABLE_COPY_AND_ASSIGN(Controller); }; diff --git a/paddle/fluid/eager/backward.cc b/paddle/fluid/eager/backward.cc index 934497d7d1..603f93d9dd 100644 --- a/paddle/fluid/eager/backward.cc +++ b/paddle/fluid/eager/backward.cc @@ -112,7 +112,8 @@ void RunBackward(const std::vector& tensors, // Prepare GradTensorHolder if (!node_input_buffers_dict.count(grad_node)) { - VLOG(6) << "Create Value for grad input tensor " << i; + VLOG(6) << "Create Value for grad input tensor " << i + << " of grad node: " << grad_node->name(); node_input_buffers_dict[grad_node] = std::make_unique(grad_node->InputMeta()); } @@ -158,19 +159,23 @@ void RunBackward(const std::vector& tensors, VLOG(6) << "Run Backward"; while (!queue.empty()) { GradNodeBase* node = queue.front(); - queue.pop(); + if (queue.size() > 1 && node_in_degree_map[node] != 0) { + queue.pop(); + continue; + } + queue.pop(); // Run node: This is where Hook happens PADDLE_ENFORCE( node_input_buffers_dict.count(node), paddle::platform::errors::Fatal( - "Unable to find next node in the InputBuufer" + "Unable to find next node in the GradTensorHolder \n" "Trying to run Node without configuring its GradTensorHolder")); std::unique_ptr node_input_buffer = std::move(node_input_buffers_dict[node]); - VLOG(6) << "Run Backward Kernel with input_buffer"; + VLOG(6) << "Run Backward Kernel with GradTensorHolder"; // Run Pre Backward Node and get outputs std::vector> grad_output_tensors = (*node)(node_input_buffer->Buffers()); @@ -215,9 +220,8 @@ void RunBackward(const std::vector& tensors, if ((!grad_output_tensor.defined() || !grad_output_tensor.initialized())) { - VLOG(6) - << "We get grad_output_tensor with slot: " << i << ", rank: " << j - << " as uninitialized or undefined in both tensor and variable"; + VLOG(6) << "We get grad_output_tensor with slot: " << i + << ", rank: " << j << " as uninitialized or undefined tensor"; } VLOG(6) << "Get Edge and grad_output_tensor with slot: " << i << ", rank: " << j @@ -228,6 +232,8 @@ void RunBackward(const std::vector& tensors, const auto& input_meta = next_node->InputMeta(); auto grad_tensor_holder = std::make_unique(input_meta); + VLOG(6) << "Construct GradTensorHolder for grad node: " + << next_node->name(); node_input_buffers_dict[next_node] = std::move(grad_tensor_holder); } VLOG(6) << "Sum grad inputs for edge slot: " << edge_rank.first @@ -237,10 +243,12 @@ void RunBackward(const std::vector& tensors, // Update queue node_in_degree_map[next_node]--; - PADDLE_ENFORCE(node_in_degree_map[next_node] >= 0, - paddle::platform::errors::Fatal( - "Detected in-degree value smaller than zero." - "Node's in-degree cannot be negative")); + PADDLE_ENFORCE( + node_in_degree_map[next_node] >= 0, + paddle::platform::errors::Fatal( + "Detected in-degree value smaller than zero. 
For Node: %s" + "Node's in-degree cannot be negative", + next_node->name())); if (node_in_degree_map[next_node] == 0) { queue.emplace(std::move(next_node)); } diff --git a/paddle/fluid/eager/custom_operator/CMakeLists.txt b/paddle/fluid/eager/custom_operator/CMakeLists.txt new file mode 100644 index 0000000000..ccc9a03a55 --- /dev/null +++ b/paddle/fluid/eager/custom_operator/CMakeLists.txt @@ -0,0 +1 @@ +cc_library(custom_operator_node SRCS custom_operator_node.cc DEPS phi_tensor phi_api grad_node_info custom_operator op_meta_info) diff --git a/paddle/fluid/eager/custom_operator/custom_operator_node.cc b/paddle/fluid/eager/custom_operator/custom_operator_node.cc new file mode 100644 index 0000000000..48ac8c8358 --- /dev/null +++ b/paddle/fluid/eager/custom_operator/custom_operator_node.cc @@ -0,0 +1,90 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/eager/custom_operator/custom_operator_node.h" +#include "paddle/fluid/framework/custom_operator.h" +#include "paddle/fluid/framework/op_meta_info_helper.h" +#include "paddle/phi/api/ext/op_meta_info.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace egr { +std::vector> RunCustomOpNode:: +operator()( + const std::vector>& grads) { + paddle::CustomOpKernelContext ctx; + auto grad_inputs_name = paddle::framework::OpMetaInfoHelper::GetInputs( + egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]); + auto grad_outputs_names = paddle::framework::OpMetaInfoHelper::GetOutputs( + egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]); + auto map = egr::Controller::Instance().GetCustomEdgesSlotMap().at(op_type_); + auto kernel_map = egr::Controller::Instance().GetOpMetaInfoMap(); + + std::vector> tmp_ins( + grad_inputs_name.size()); + VLOG(7) << " Prepare Backward inputs of grads with size: " << grads.size() + << ", whose grad_inputs_name size is: " << grad_inputs_name.size(); + for (size_t i = 0; i < grads.size(); i++) { + if (map[1].find(i) != map[1].end()) { + VLOG(7) << "Insert grad: " << i << " to grad_inputs: " << map[1][i]; + tmp_ins[map[1][i]] = grads[i]; + } + } + + for (auto it : fwd_outs) { + VLOG(7) << "Insert fwd_outs to grad_inputs: " << it.first; + tmp_ins[it.first] = RunCustomOpNode::Recover(&(it.second)); + } + + for (auto it : fwd_ins) { + VLOG(7) << "Insert fwd_ins to grad_inputs: " << it.first; + tmp_ins[it.first] = RunCustomOpNode::Recover(&(it.second)); + } + + VLOG(6) << "Prepare Grad inputs"; + for (const auto& in : tmp_ins) { + ctx.EmplaceBackInputs(in); + } + VLOG(6) << "Prepare Grad attrs"; + ctx.EmplaceBackAttrs(attrs_); + std::vector> outs( + GetEdges().size()); + std::vector> tmp_outs( + grad_outputs_names.size()); + VLOG(6) << "Prepare Grad outputs for size: " << grad_outputs_names.size(); + for (size_t i = 0; i < GetEdges().size(); i++) { + if (map[0].find(i) != map[0].end()) { + VLOG(7) << "Insert grad outputs: " << i + << " with size: " << GetEdges()[i].size() + << " to tmp_outputs: " << 
map[0][i]; + for (size_t j = 0; j < GetEdges()[i].size(); j++) { + outs[i].emplace_back(/* init it incase of copy nullptr of shared_ptr */ + std::make_shared( + phi::DataType::UNDEFINED), + egr::Controller::Instance().GenerateUniqueName( + "custom_tmp_grad")); + } + tmp_outs[map[0][i]] = outs[i]; + } + } + for (size_t i = 0; i < tmp_outs.size(); i++) { + VLOG(7) << "Prepare grad outputs size: " << tmp_outs[i].size(); + ctx.EmplaceBackOutputs(tmp_outs[i]); + } + VLOG(7) << "Run Kernel of Grad Custom Op: " << op_type_; + + (*paddle::framework::OpMetaInfoHelper::GetKernelFn( + kernel_map.at(op_type_)[1]))(&ctx); + return outs; +} +} // namespace egr diff --git a/paddle/fluid/eager/custom_operator/custom_operator_node.h b/paddle/fluid/eager/custom_operator/custom_operator_node.h new file mode 100644 index 0000000000..e5ddef9c06 --- /dev/null +++ b/paddle/fluid/eager/custom_operator/custom_operator_node.h @@ -0,0 +1,77 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/eager/autograd_meta.h" +#include "paddle/fluid/eager/grad_node_info.h" +#include "paddle/fluid/eager/hooks.h" +#include "paddle/fluid/eager/tensor_wrapper.h" +#include "paddle/fluid/framework/custom_operator.h" +#include "paddle/utils/any.h" + +namespace egr { +class RunCustomOpNode : public GradNodeBase { + public: + // Constructor: configure fwd input tensors to grad node + explicit RunCustomOpNode(size_t bwd_in_slot_num, size_t bwd_out_slot_num, + const std::string& op_type) + : GradNodeBase(bwd_in_slot_num, bwd_out_slot_num), op_type_(op_type) { + VLOG(6) << "Construct RunCustomOpNode for op: " << op_type; + } + + ~RunCustomOpNode() override { + VLOG(6) << "Destruct RunCustomOpNode for op: " << op_type_; + } + + // Functor: perform backward computations + virtual std::vector> operator()( + const std::vector>& grads) + override; + + std::string name() { + return paddle::string::Sprintf("RunCustomOpNode: %s_grad", op_type_); + } + + static std::vector ConstructTensorWrapper( + const std::vector& fwd_var) { + std::vector res; + for (auto const& var : fwd_var) { + res.emplace_back(var); + } + return res; + } + + static std::vector Recover( + std::vector* fwd_var) { + std::vector res; + for (size_t i = 0; i < fwd_var->size(); i++) { + res.emplace_back(fwd_var->at(i).recover(nullptr)); + } + return res; + } + + void SetAttrs(const std::vector& attr) { attrs_ = attr; } + + public: + std::unordered_map> fwd_outs; + std::unordered_map> fwd_ins; + std::unordered_map grads2grad_in_map; + + private: + std::vector attrs_; + std::string op_type_{""}; +}; + +} // namespace egr diff --git a/paddle/fluid/eager/grad_node_info.cc b/paddle/fluid/eager/grad_node_info.cc index 427be83c3b..7eb2902d93 100644 --- a/paddle/fluid/eager/grad_node_info.cc +++ b/paddle/fluid/eager/grad_node_info.cc @@ -25,7 +25,7 @@ #include "glog/logging.h" /** - * Implementation of GradNodeBase, Edge and InputBuffer. 
+ * Implementation of GradNodeBase, Edge and GradTensorHolder. **/ namespace egr { diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index aa92a3b222..5dc3d9e89c 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -440,6 +440,7 @@ message(STATUS "branch: ${PADDLE_BRANCH}") configure_file(commit.h.in commit.h) cc_library(custom_operator SRCS custom_operator.cc DEPS tensor attribute framework_proto op_registry operator dynamic_loader string_helper phi_tensor op_meta_info phi_api) + #cc_binary(test_executor SRCS test_executor.cc DEPS executor op_registry ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} ) #cc_binary(new_executor SRCS new_exec_test.cc DEPS operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler) diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc index b9e3bee25f..478e39b99d 100644 --- a/paddle/fluid/framework/custom_operator.cc +++ b/paddle/fluid/framework/custom_operator.cc @@ -25,6 +25,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/framework/attribute.h" #include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/op_meta_info_helper.h" @@ -946,15 +947,16 @@ void RegisterOperatorWithMetaInfoMap( ////////////////////// User APIs /////////////////////// // load op api -void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name) { +const std::unordered_map>& +LoadOpMetaInfoAndRegisterOp(const std::string& dso_name) { void* handle = paddle::platform::dynload::GetOpDsoHandle(dso_name); VLOG(3) << "load custom_op lib: " << dso_name; typedef OpMetaInfoMap& get_op_meta_info_map_t(); auto* get_op_meta_info_map = detail::DynLoad(handle, "PD_GetOpMetaInfoMap"); auto& op_meta_info_map = get_op_meta_info_map(); - RegisterOperatorWithMetaInfoMap(op_meta_info_map, handle); + return op_meta_info_map.GetMap(); } } // namespace framework diff --git a/paddle/fluid/framework/custom_operator.h b/paddle/fluid/framework/custom_operator.h index 4310b56437..fef1e82a14 100644 --- a/paddle/fluid/framework/custom_operator.h +++ b/paddle/fluid/framework/custom_operator.h @@ -20,9 +20,9 @@ limitations under the License. */ namespace paddle { namespace framework { - // Load custom op api: register op after user compiled -void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name); +const std::unordered_map>& +LoadOpMetaInfoAndRegisterOp(const std::string& dso_name); // Register custom op api: register op directly void RegisterOperatorWithMetaInfoMap( @@ -31,6 +31,5 @@ void RegisterOperatorWithMetaInfoMap( // Interface for selective register custom op. 
void RegisterOperatorWithMetaInfo(const std::vector& op_meta_infos, void* dso_handle = nullptr); - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 2e901f3bff..7b223f7ed2 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -351,7 +351,7 @@ if(WITH_PYTHON) if(NOT ((NOT WITH_PYTHON) AND ON_INFER)) cc_library(paddle_eager SRCS eager.cc eager_functions.cc eager_method.cc eager_properties.cc eager_utils.cc - DEPS eager_api autograd_meta backward grad_node_info phi op_function_common final_dygraph_function final_dygraph_node dygraph_function dygraph_node accumulation_node global_utils utils python) + DEPS eager_api autograd_meta backward grad_node_info phi op_function_common final_dygraph_function final_dygraph_node dygraph_function dygraph_node accumulation_node global_utils utils python custom_operator custom_operator_node) add_dependencies(paddle_eager eager_codegen) add_dependencies(paddle_eager eager_op_function_generator_cmd) list(APPEND PYBIND_DEPS paddle_eager) diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc index 0b04dc7347..e110432c67 100644 --- a/paddle/fluid/pybind/eager_functions.cc +++ b/paddle/fluid/pybind/eager_functions.cc @@ -21,21 +21,25 @@ limitations under the License. */ #include "paddle/fluid/eager/api/all.h" #include "paddle/fluid/eager/autograd_meta.h" #include "paddle/fluid/eager/backward.h" +#include "paddle/fluid/eager/custom_operator/custom_operator_node.h" #include "paddle/fluid/eager/utils.h" #include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/framework/custom_operator.h" +#include "paddle/fluid/framework/op_meta_info_helper.h" #include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/platform/dynload/dynamic_loader.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/pybind/eager.h" #include "paddle/fluid/pybind/eager_utils.h" #include "paddle/fluid/pybind/exception.h" +#include "paddle/phi/api/ext/op_meta_info.h" #include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/api/lib/utils/storage.h" #include "paddle/phi/api/lib/utils/tensor_utils.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/dense_tensor.h" - namespace paddle { namespace pybind { @@ -168,7 +172,276 @@ static PyObject* eager_api_read_next_tensor_list(PyObject* self, PyObject* args, EAGER_CATCH_AND_THROW_RETURN_NULL } +static void ConstructFwdAndBwdMap( + const std::vector& vec_map, + const std::string& op_type) { + auto& in_out_map = egr::Controller::Instance().GetCustomEdgesSlotMap(); + if (in_out_map.find(op_type) != in_out_map.end()) { + VLOG(7) << "Find Exist CustomEdgesSlotMap Skip >>>> "; + return; + } else { + VLOG(7) << "Construct CustomEdgesSlotMap "; + auto inputs_names = + paddle::framework::OpMetaInfoHelper::GetInputs(vec_map[0]); + auto outputs_names = + paddle::framework::OpMetaInfoHelper::GetOutputs(vec_map[0]); + auto attrs_names = + paddle::framework::OpMetaInfoHelper::GetAttrs(vec_map[0]); + auto grad_outputs_names = + paddle::framework::OpMetaInfoHelper::GetOutputs(vec_map[1]); + auto grad_inputs_names = + paddle::framework::OpMetaInfoHelper::GetInputs(vec_map[1]); + auto grad_attrs_names = + paddle::framework::OpMetaInfoHelper::GetAttrs(vec_map[1]); + std::vector> res(5); + in_out_map.insert({op_type, res}); + // Prepare pos map for 
grad_outputs + VLOG(7) << "Prepare pos map for grad_outputs"; + PADDLE_ENFORCE_LE( + grad_outputs_names.size(), inputs_names.size(), + paddle::platform::errors::InvalidArgument( + "Grad outputs num should be less equal than forward inputs num.")); + for (size_t i = 0; i < grad_outputs_names.size(); i++) { + size_t end = grad_outputs_names[i].find("@GRAD"); + PADDLE_ENFORCE_NE( + end, std::string::npos, + paddle::platform::errors::NotFound( + "All Grad outputs should be grad and we got %s is not grad var, " + "please check your op and change to fit the rule.", + grad_outputs_names[i])); + for (size_t j = 0; j < inputs_names.size(); j++) { + if (grad_outputs_names[i].substr(0, end) == inputs_names[j]) { + VLOG(7) << " ==== Custom Operator: " << op_type << "'s No." << j + << " inputs: " << inputs_names[j] << " related to No." << i + << " grad_outputs: " << grad_outputs_names[i]; + in_out_map[op_type][0][j] = i; + } + } + } + // Prepare pos map for grad_inputs + for (size_t i = 0; i < grad_inputs_names.size(); i++) { + size_t end = grad_inputs_names[i].find("@GRAD"); + if (end != std::string::npos) { + for (size_t j = 0; j < outputs_names.size(); j++) { + if (grad_inputs_names[i].substr(0, end) == outputs_names[j]) { + VLOG(7) << " ==== Custom Operator: " << op_type << "'s No." << j + << " outputs: " << outputs_names[j] << " related to No." + << i << " grad_inputs's grad: " << grad_inputs_names[i]; + in_out_map[op_type][1][j] = i; + } + } + } else { + if (std::find(outputs_names.begin(), outputs_names.end(), + grad_inputs_names[i]) != outputs_names.end()) { + for (size_t j = 0; j < outputs_names.size(); j++) { + if (grad_inputs_names[i] == outputs_names[j]) { + VLOG(7) << " ==== Custom Operator: " << op_type << "'s No." << j + << " outputs: " << outputs_names[j] << " related to No." + << i + << " grad_inputs fwd outputs: " << grad_inputs_names[i]; + in_out_map[op_type][2][j] = i; + } + } + } else { + for (size_t j = 0; j < inputs_names.size(); j++) { + if (grad_inputs_names[i] == inputs_names[j]) { + VLOG(7) << " ==== Custom Operator: " << op_type << "'s No." << j + << " inputs: " << inputs_names[j] << " related to No." + << i + << " grad_inputs fwd inputs: " << grad_inputs_names[i]; + in_out_map[op_type][3][j] = i; + } + } + } + } + } + + // Prepare pos map for grad attrs_ + for (size_t i = 0; i < grad_attrs_names.size(); i++) { + auto end = std::find(attrs_names.begin(), attrs_names.end(), + grad_attrs_names[i]); + PADDLE_ENFORCE_NE(end, attrs_names.end(), + paddle::platform::errors::NotFound( + "All Grad attrs should be one of forward attrs and " + "we got %s is not one of them, please check your " + "op and change to fit the rule.", + grad_attrs_names[i])); + for (size_t j = 0; j < attrs_names.size(); j++) { + if (grad_attrs_names[i] == attrs_names[j]) { + VLOG(7) << " ==== Custom Operator: " << op_type << "'s No." << j + << " attrs: " << attrs_names[j] << " related to No." 
<< i + << " grad_attrs: " << grad_attrs_names[i]; + in_out_map[op_type][4][j] = i; + } + } + } + } +} + +static std::vector CastAttrsToTragetType( + const std::vector& src, + const std::vector& attrs_names) { + std::vector res; + PADDLE_ENFORCE_EQ(src.size(), attrs_names.size(), + paddle::platform::errors::InvalidArgument( + "We Expected same size of attrs and attrs_name list, " + "if u got this error indicate your custom op setting " + "%s attrs, but you just give %s", + attrs_names.size(), src.size())); + for (size_t i = 0; i < src.size(); i++) { + size_t end = attrs_names[i].find(": "); + std::string type_name = + attrs_names[i].substr(end + 2, attrs_names.size() - end - 2); + if (type_name == "int") { + if (src[i].type() == typeid(bool)) { + res.emplace_back(static_cast(paddle::any_cast(src[i]))); + } else if (src[i].type() == typeid(int)) { + res.emplace_back(src[i]); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Your No. %s attrs should only can be bool or int32, other type is " + "forbidden for now but we got %s. Check your code first please", + i, src[i].type().name())); + } + } else if (type_name == "int64_t") { + if (src[i].type() == typeid(bool)) { + res.emplace_back(static_cast(paddle::any_cast(src[i]))); + } else if (src[i].type() == typeid(int)) { + res.emplace_back(static_cast(paddle::any_cast(src[i]))); + } else if (src[i].type() == typeid(int64_t)) { + res.emplace_back(src[i]); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Your No. %s attrs should only can be bool or int32 or int64_t, " + "other type is forbidden for now but we got %s. Check your code " + "first please", + i, src[i].type().name())); + } + } else { + res.emplace_back(src[i]); + } + } + return res; +} + +static PyObject* eager_api_run_costum_op(PyObject* self, PyObject* args, + PyObject* kwargs) { + EAGER_TRY + paddle::CustomOpKernelContext ctx = + CastPyArg2CustomOpKernelContext(PyTuple_GET_ITEM(args, 0), 0); + std::string op_type = CastPyArg2AttrString(PyTuple_GET_ITEM(args, 1), 1); + bool trace_backward = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 2), 2); + VLOG(7) << "Get things for python for Custom Op: " << op_type + << ", trace_backward is: " << trace_backward; + auto meta_info_map = egr::Controller::Instance().GetOpMetaInfoMap(); + PADDLE_ENFORCE_NE(meta_info_map.find(op_type), meta_info_map.end(), + paddle::platform::errors::NotFound( + "Can't find %s in Eager OpMetaInfoMap which should be " + "created by LoadOpMetaInfoAndRegisterOp, please make " + "sure you registered your op first and try again. 
", + op_type)); + VLOG(7) << "Run Kernel of Custom Op: " << op_type; + std::vector res_attrs = CastAttrsToTragetType( + ctx.Attrs(), paddle::framework::OpMetaInfoHelper::GetAttrs( + meta_info_map.at(op_type)[0])); + ctx.EmplaceBackAttrs(res_attrs); + const auto& vec_map = meta_info_map.at(op_type); + (*paddle::framework::OpMetaInfoHelper::GetKernelFn(vec_map[0]))(&ctx); + + VLOG(7) << "Get AutogradMeta for inputs and outputs for Custom Op"; + std::vector> ins_auto_grad_metas; + std::vector> outs_auto_grad_metas; + VLOG(7) << "We got slot num of ins is: " << ctx.InputRange().size(); + ins_auto_grad_metas.resize(ctx.InputRange().size()); + VLOG(7) << "We got slot num of outs is: " << ctx.OutputRange().size(); + outs_auto_grad_metas.resize(ctx.OutputRange().size()); + for (size_t i = 0; i < ctx.InputRange().size(); i++) { + ins_auto_grad_metas[i] = + egr::EagerUtils::nullable_autograd_meta(ctx.InputsBetween( + ctx.InputRangeAt(i).first, ctx.InputRangeAt(i).second)); + } + for (size_t i = 0; i < ctx.OutputRange().size(); i++) { + outs_auto_grad_metas[i] = + egr::EagerUtils::unsafe_autograd_meta(ctx.OutputsBetweeen( + ctx.OutputRangeAt(i).first, ctx.OutputRangeAt(i).second)); + } + bool require_any_grad = false; + for (size_t i = 0; i < ins_auto_grad_metas.size(); i++) { + require_any_grad = + require_any_grad || egr::EagerUtils::ComputeRequireGrad( + trace_backward, &(ins_auto_grad_metas[i])); + } + if (require_any_grad) { + VLOG(6) << " Construct Grad for Custom Op: " << op_type; + ConstructFwdAndBwdMap(vec_map, op_type); + for (size_t i = 0; i < outs_auto_grad_metas.size(); i++) { + egr::EagerUtils::PassStopGradient(false, &(outs_auto_grad_metas[i])); + } + auto grad_node = std::make_shared( + outs_auto_grad_metas.size(), ins_auto_grad_metas.size(), op_type); + auto slot_map = + egr::Controller::Instance().GetCustomEdgesSlotMap().at(op_type); + // Prepare Grad outputs + size_t no_grad_cnt = 0; + for (size_t i = 0; i < ins_auto_grad_metas.size(); i++) { + if (slot_map[0].find(i) != slot_map[0].end()) { + grad_node->SetGradOutMeta(&ins_auto_grad_metas[i], slot_map[0][i]); + grad_node->AddEdges(&ins_auto_grad_metas[i], slot_map[0][i]); + } else { + grad_node->SetGradOutMeta(&ins_auto_grad_metas[i], + ins_auto_grad_metas.size() - 1 - no_grad_cnt); + grad_node->AddEdges(&ins_auto_grad_metas[i], + ins_auto_grad_metas.size() - 1 - no_grad_cnt); + no_grad_cnt++; + } + } + // Prepare Grad inputs with grad of fwd outputs + for (size_t i = 0; i < outs_auto_grad_metas.size(); i++) { + egr::EagerUtils::SetOutRankWithSlot(&(outs_auto_grad_metas[i]), i); + egr::EagerUtils::SetHistory(&(outs_auto_grad_metas[i]), grad_node); + grad_node->SetGradInMeta(&(outs_auto_grad_metas[i]), i); + egr::EagerUtils::CheckAndRetainGrad(ctx.OutputsBetweeen( + ctx.OutputRangeAt(i).first, ctx.OutputRangeAt(i).second)); + } + + // Prepare Grad inputs with fwd outputs + for (auto it = slot_map[2].begin(); it != slot_map[2].end(); it++) { + VLOG(7) << "Prepare fwd_outs: " << it->first + << " to grad_inputs: " << it->second; + grad_node->fwd_outs[it->second] = + egr::RunCustomOpNode::ConstructTensorWrapper( + ctx.OutputsBetweeen(ctx.OutputRangeAt(it->first).first, + ctx.OutputRangeAt(it->first).second)); + } + + // Prepare Grad inputs with fwd inputs + for (auto it = slot_map[3].begin(); it != slot_map[3].end(); it++) { + VLOG(7) << "Prepare fwd_ins: " << it->first + << " to grad_inputs: " << it->second; + grad_node->fwd_ins[it->second] = + egr::RunCustomOpNode::ConstructTensorWrapper( + 
ctx.InputsBetween(ctx.InputRangeAt(it->first).first, + ctx.InputRangeAt(it->first).second)); + } + + auto attrs_names = paddle::framework::OpMetaInfoHelper::GetAttrs( + meta_info_map.at(op_type)[1]); + std::vector attrs(attrs_names.size()); + // Prepare attrs for Grad node + for (auto it = slot_map[4].begin(); it != slot_map[4].end(); it++) { + VLOG(7) << "Prepare fwd attrs: " << it->first + << " to grad_attrs: " << it->second; + attrs[it->second] = res_attrs[it->first]; + } + grad_node->SetAttrs(attrs); + } + Py_INCREF(Py_None); + return Py_None; + EAGER_CATCH_AND_THROW_RETURN_NULL +} + PyMethodDef variable_functions[] = { + // TODO(jiabin): Remove scale when we have final state tests {"scale", (PyCFunction)(void (*)(void))eager_api_scale, METH_VARARGS | METH_KEYWORDS, NULL}, {"_set_expected_place", @@ -179,6 +452,8 @@ PyMethodDef variable_functions[] = { METH_VARARGS | METH_KEYWORDS, NULL}, {"run_backward", (PyCFunction)(void (*)(void))eager_api_run_backward, METH_VARARGS | METH_KEYWORDS, NULL}, + {"_run_custom_op", (PyCFunction)(void (*)(void))eager_api_run_costum_op, + METH_VARARGS | METH_KEYWORDS, NULL}, {"tensor_copy", (PyCFunction)(void (*)(void))eager_api_tensor_copy, METH_VARARGS | METH_KEYWORDS, NULL}, {"read_next_tensor_list", diff --git a/paddle/fluid/pybind/eager_properties.cc b/paddle/fluid/pybind/eager_properties.cc index 2e1390cb96..2572866b8f 100644 --- a/paddle/fluid/pybind/eager_properties.cc +++ b/paddle/fluid/pybind/eager_properties.cc @@ -72,7 +72,7 @@ PyObject* tensor_properties_get_grad(TensorObject* self, void* closure) { EAGER_TRY VLOG(6) << "Get grad for tensor: " << self->tensor.name(); auto meta = egr::EagerUtils::nullable_autograd_meta(self->tensor); - if (meta) { + if (meta && meta->Grad().initialized()) { return ToPyObject(meta->Grad()); } else { Py_INCREF(Py_None); diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index f4e148cf8d..217edad0c0 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -27,10 +27,10 @@ limitations under the License. 
*/ #include "paddle/fluid/pybind/eager_utils.h" #include "paddle/fluid/pybind/op_function_common.h" #include "paddle/fluid/pybind/tensor_py.h" +#include "paddle/phi/api/ext/op_meta_info.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/dense_tensor.h" - namespace paddle { namespace pybind { @@ -46,6 +46,7 @@ extern PyTypeObject* g_npuplace_pytype; extern PyTypeObject* g_cudapinnedplace_pytype; extern PyTypeObject* g_framework_tensor_pytype; extern PyTypeObject* g_framework_lodtensorarray_pytype; +extern PyTypeObject* g_custom_op_kernel_ctx_pytype; int TensorDtype2NumpyDtype(phi::DataType dtype) { switch (dtype) { @@ -184,7 +185,7 @@ paddle::experimental::Tensor CastPyArg2Tensor(PyObject* obj, ssize_t arg_pos) { } else { PADDLE_THROW(platform::errors::InvalidArgument( "argument (position %d) must be " - "EagerVariable, but got %s", + "Tensor, but got %s", arg_pos + 1, reinterpret_cast(obj->ob_type)->tp_name)); } } @@ -319,7 +320,7 @@ framework::Tensor CastPyArg2FrameworkTensor(PyObject* obj, ssize_t arg_pos) { } else { PADDLE_THROW(platform::errors::InvalidArgument( "argument (position %d) must be " - "EagerVariable, but got %s", + "DenseTensor, but got %s", arg_pos + 1, reinterpret_cast(obj->ob_type)->tp_name)); } } @@ -391,6 +392,19 @@ paddle::framework::proto::VarType::Type CastPyArg2ProtoType(PyObject* obj, return dtype; } +paddle::CustomOpKernelContext CastPyArg2CustomOpKernelContext(PyObject* obj, + ssize_t arg_pos) { + if (PyObject_IsInstance( + obj, reinterpret_cast(g_custom_op_kernel_ctx_pytype))) { + return ::pybind11::handle(obj).cast(); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "argument (position %d) must be " + "one of(Place,CUDAPlace,CPUPlace,XPUPlace,NPUPlace,CUDAPinnedPlace), " + "but got %s", + arg_pos + 1, reinterpret_cast(obj->ob_type)->tp_name)); + } +} PyObject* ToPyObject(bool value) { if (value) { Py_INCREF(Py_True); @@ -928,6 +942,5 @@ paddle::experimental::DataType CastPyArg2DataType(PyObject* obj, framework::proto::VarType::Type type = CastPyArg2ProtoType(obj, arg_pos); return framework::TransToPhiDataType(type); } - } // namespace pybind } // namespace paddle diff --git a/paddle/fluid/pybind/eager_utils.h b/paddle/fluid/pybind/eager_utils.h index 966a920377..2187555e1c 100644 --- a/paddle/fluid/pybind/eager_utils.h +++ b/paddle/fluid/pybind/eager_utils.h @@ -20,10 +20,10 @@ limitations under the License. 
*/ #include "pybind11/pybind11.h" #include "pybind11/stl.h" namespace paddle { +class CustomOpKernelContext; namespace framework { class Scope; } - namespace pybind { typedef struct { @@ -40,6 +40,8 @@ int CastPyArg2AttrInt(PyObject* obj, ssize_t arg_pos); int64_t CastPyArg2AttrLong(PyObject* obj, ssize_t arg_pos); float CastPyArg2AttrFloat(PyObject* obj, ssize_t arg_pos); std::string CastPyArg2AttrString(PyObject* obj, ssize_t arg_pos); +paddle::CustomOpKernelContext CastPyArg2CustomOpKernelContext(PyObject* obj, + ssize_t arg_pos); paddle::experimental::Tensor CastPyArg2Tensor(PyObject* obj, ssize_t arg_pos); std::shared_ptr CastPyArg2VarBase(PyObject* obj, ssize_t arg_pos); @@ -52,6 +54,7 @@ std::vector CastPyArg2VectorOfTensorBase(PyObject* obj, std::vector CastPyArg2VectorOfInt(PyObject* obj, size_t arg_pos); framework::proto::VarType::Type CastPyArg2ProtoType(PyObject* obj, ssize_t arg_pos); + PyObject* ToPyObject(int value); PyObject* ToPyObject(bool value); PyObject* ToPyObject(int64_t value); @@ -138,6 +141,7 @@ std::vector GetTensorPtrListFromArgs( ssize_t arg_idx, bool dispensable = false); // end of Slice related methods + std::vector GetScopePtrListFromArgs( const std::string& op_type, const std::string& arg_name, PyObject* args, ssize_t arg_idx, bool dispensable); diff --git a/paddle/fluid/pybind/exception.cc b/paddle/fluid/pybind/exception.cc index 362a3e44fa..4f25a6f1a5 100644 --- a/paddle/fluid/pybind/exception.cc +++ b/paddle/fluid/pybind/exception.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/pybind/exception.h" - +#include "paddle/phi/api/ext/exception.h" namespace paddle { namespace pybind { @@ -122,6 +122,8 @@ void ThrowExceptionToPython(std::exception_ptr p) { PyErr_SetString(EnforceNotMetException, e.what()); break; } + } catch (const paddle::PD_Exception& e) { + PyErr_SetString(PyExc_OSError, e.what()); } } } // namespace pybind diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index ee6dce5dc2..21bbc7f3e3 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -164,6 +164,9 @@ limitations under the License. 
*/ #include "paddle/fluid/pybind/fleet_py.h" #endif +#include "paddle/fluid/eager/api/utils/global_utils.h" +#include "paddle/fluid/pybind/eager_utils.h" +#include "paddle/phi/api/ext/op_meta_info.h" #include "pybind11/stl.h" DECLARE_bool(use_mkldnn); @@ -187,6 +190,7 @@ PyTypeObject *g_cudapinnedplace_pytype = nullptr; PyTypeObject *g_mluplace_pytype = nullptr; PyTypeObject *g_framework_tensor_pytype = nullptr; PyTypeObject *g_framework_lodtensorarray_pytype = nullptr; +PyTypeObject *g_custom_op_kernel_ctx_pytype = nullptr; bool IsCompiledWithCUDA() { #if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) @@ -757,6 +761,57 @@ PYBIND11_MODULE(core_noavx, m) { m.def("_promote_types_if_complex_exists", &paddle::framework::PromoteTypesIfComplexExists); + py::class_ custom_op_kernel_ctx( + m, "CustomOpKernelContext", R"DOC()DOC"); + g_custom_op_kernel_ctx_pytype = + reinterpret_cast(custom_op_kernel_ctx.ptr()); + custom_op_kernel_ctx.def(py::init<>()) + .def("add_inputs", + [](paddle::CustomOpKernelContext &self, const py::handle &input) { + PyObject *obj = input.ptr(); + if (PyList_Check(obj) || PyTuple_Check(obj)) { + self.EmplaceBackInputs( + std::move(CastPyArg2VectorOfTensor(obj, 1))); + } else { + self.EmplaceBackInput(std::move(CastPyArg2Tensor(obj, 1))); + } + }) + .def("add_outputs", + [](paddle::CustomOpKernelContext &self, py::handle &outputs) { + PyObject *obj = outputs.ptr(); + if (PyList_Check(obj) || PyTuple_Check(obj)) { + self.EmplaceBackOutputs( + std::move(CastPyArg2VectorOfTensor(obj, 1))); + } else { + self.EmplaceBackOutput(std::move(CastPyArg2Tensor(obj, 1))); + } + }) + .def("add_attr", [](paddle::CustomOpKernelContext &self, + bool attr) { self.EmplaceBackAttr(attr); }) + .def("add_attr", [](paddle::CustomOpKernelContext &self, + int attr) { self.EmplaceBackAttr(attr); }) + .def("add_attr", [](paddle::CustomOpKernelContext &self, + float attr) { self.EmplaceBackAttr(attr); }) + .def("add_attr", [](paddle::CustomOpKernelContext &self, + int64_t attr) { self.EmplaceBackAttr(attr); }) + .def("add_attr", + [](paddle::CustomOpKernelContext &self, const std::string &attr) { + self.EmplaceBackAttr(attr); + }) + .def("add_attr", + [](paddle::CustomOpKernelContext &self, + const std::vector &attr) { self.EmplaceBackAttr(attr); }) + .def("add_attr", + [](paddle::CustomOpKernelContext &self, + const std::vector &attr) { self.EmplaceBackAttr(attr); }) + .def("add_attr", + [](paddle::CustomOpKernelContext &self, + const std::vector &attr) { self.EmplaceBackAttr(attr); }) + .def("add_attr", [](paddle::CustomOpKernelContext &self, + const std::vector &attr) { + self.EmplaceBackAttr(attr); + }); + py::class_ framework_tensor(m, "Tensor", py::buffer_protocol()); g_framework_tensor_pytype = @@ -2827,10 +2882,11 @@ All parameter, weight, gradient are variables in Paddle. 
m.def("init_gflags", framework::InitGflags); m.def("init_glog", framework::InitGLOG); - m.def("load_op_meta_info_and_register_op", - framework::LoadOpMetaInfoAndRegisterOp); + m.def("load_op_meta_info_and_register_op", [](const std::string dso_name) { + egr::Controller::Instance().MergeOpMetaInfoMap( + framework::LoadOpMetaInfoAndRegisterOp(dso_name)); + }); m.def("init_devices", []() { framework::InitDevices(); }); - m.def("is_compiled_with_cuda", IsCompiledWithCUDA); m.def("is_compiled_with_ascend", IsCompiledWithAscend); m.def("is_compiled_with_rocm", IsCompiledWithROCM); diff --git a/paddle/phi/api/ext/op_meta_info.h b/paddle/phi/api/ext/op_meta_info.h index 7601696293..88660449b6 100644 --- a/paddle/phi/api/ext/op_meta_info.h +++ b/paddle/phi/api/ext/op_meta_info.h @@ -86,19 +86,28 @@ class PADDLE_API CustomOpKernelContext { CustomOpKernelContext() = default; void EmplaceBackInput(Tensor&& input); - void EmplaceBackInputs(std::vector&& inputs); + void EmplaceBackInputs(const std::vector& inputs); void EmplaceBackOutput(Tensor&& output); - void EmplaceBackOutputs(std::vector&& outputs); + void EmplaceBackOutputs(const std::vector& outputs); void EmplaceBackAttr(paddle::any attr); - + void EmplaceBackAttrs(const std::vector& attrs) { + attrs_ = std::move(attrs); + } const std::pair& InputRangeAt(size_t idx) const; const std::pair& OutputRangeAt(size_t idx) const; const Tensor& InputAt(size_t idx) const; std::vector InputsBetween(size_t start, size_t end) const; - + const std::vector& Attrs() const { return attrs_; } + const std::vector>& InputRange() { + return input_range_; + } + const std::vector>& OutputRange() { + return output_range_; + } Tensor* MutableOutputAt(size_t idx); std::vector MutableOutputBetweeen(size_t start, size_t end); + std::vector OutputsBetweeen(size_t start, size_t end); std::vector* AllMutableOutput(); template @@ -552,7 +561,6 @@ class PADDLE_API OpMetaInfo { std::vector inputs_; std::vector outputs_; std::vector attrs_; - // 2. func info KernelFunc kernel_fn_{nullptr}; InferShapeFunc infer_shape_fn_{nullptr}; diff --git a/paddle/phi/api/lib/op_meta_info.cc b/paddle/phi/api/lib/op_meta_info.cc index 51d51c954d..14dba664c4 100644 --- a/paddle/phi/api/lib/op_meta_info.cc +++ b/paddle/phi/api/lib/op_meta_info.cc @@ -51,7 +51,8 @@ void CustomOpKernelContext::EmplaceBackInput(Tensor&& input) { input_range_.emplace_back(std::make_pair(index, index + 1)); } -void CustomOpKernelContext::EmplaceBackInputs(std::vector&& inputs) { +void CustomOpKernelContext::EmplaceBackInputs( + const std::vector& inputs) { size_t index = inputs_.size(); input_range_.emplace_back(std::make_pair(index, index + inputs.size())); inputs_.insert(inputs_.end(), @@ -65,7 +66,8 @@ void CustomOpKernelContext::EmplaceBackOutput(Tensor&& output) { output_range_.emplace_back(std::make_pair(index, index + 1)); } -void CustomOpKernelContext::EmplaceBackOutputs(std::vector&& outputs) { +void CustomOpKernelContext::EmplaceBackOutputs( + const std::vector& outputs) { size_t index = outputs_.size(); output_range_.emplace_back(std::make_pair(index, index + outputs.size())); outputs_.insert(outputs_.end(), @@ -75,6 +77,8 @@ void CustomOpKernelContext::EmplaceBackOutputs(std::vector&& outputs) { void CustomOpKernelContext::EmplaceBackAttr(paddle::any attr) { attrs_.emplace_back(std::move(attr)); + VLOG(7) << "attrs_ No." 
<< attrs_.size() - 1 + << " has value of type: " << attrs_[attrs_.size() - 1].type().name(); } const Tensor& CustomOpKernelContext::InputAt(size_t idx) const { @@ -102,6 +106,15 @@ std::vector CustomOpKernelContext::MutableOutputBetweeen(size_t start, return rlt; } +std::vector CustomOpKernelContext::OutputsBetweeen(size_t start, + size_t end) { + std::vector rlt; + for (size_t i = start; i < end; ++i) { + rlt.emplace_back(outputs_.at(i)); + } + return rlt; +} + std::vector* CustomOpKernelContext::AllMutableOutput() { return &outputs_; } diff --git a/paddle/phi/api/lib/tensor.cc b/paddle/phi/api/lib/tensor.cc index 311dd0fc30..40174a505d 100644 --- a/paddle/phi/api/lib/tensor.cc +++ b/paddle/phi/api/lib/tensor.cc @@ -111,8 +111,8 @@ void Tensor::reshape(const std::vector &shape) { "touching underlying data, this requires the total size of " "the tensor to remain constant."; if (is_dense_tensor()) { - std::dynamic_pointer_cast(impl_)->set_meta( - phi::DenseTensorMeta(dtype(), phi::make_ddim(shape))); + std::dynamic_pointer_cast(impl_)->Resize( + phi::make_ddim(shape)); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support reshape operation on DenseTensor now.")); diff --git a/paddle/phi/api/lib/tensor_method.cc b/paddle/phi/api/lib/tensor_method.cc index aefa26952d..885e29b27f 100644 --- a/paddle/phi/api/lib/tensor_method.cc +++ b/paddle/phi/api/lib/tensor_method.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include "paddle/phi/api/include/tensor.h" #include "paddle/phi/api/lib/ext_compat_utils.h" +#include "paddle/phi/common/scalar_array.h" #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/tensor_base.h" diff --git a/python/paddle/autograd/backward_mode.py b/python/paddle/autograd/backward_mode.py index 36ca048c51..6fc6f7d3d4 100644 --- a/python/paddle/autograd/backward_mode.py +++ b/python/paddle/autograd/backward_mode.py @@ -81,15 +81,14 @@ def backward(tensors, grad_tensors=None, retain_graph=False): if isinstance(in_out_list, (list, tuple)): assert len(in_out_list) > 0, "{} connot be empyt".format(name) for each_var in in_out_list: - assert isinstance( - each_var, paddle. - Tensor), "Elements of {} must be paddle.Tensor".format(name) + assert isinstance(each_var, ( + paddle.Tensor, core.eager.Tensor + )), "Elements of {} must be paddle.Tensor".format(name) return in_out_list else: - assert isinstance( - in_out_list, - paddle.Tensor), "{} must be Tensor or list of Tensor".format( - name) + assert isinstance(in_out_list, ( + paddle.Tensor, core.eager.Tensor + )), "{} must be Tensor or list of Tensor".format(name) return [in_out_list] tensors = check_tensors(tensors, "tensors") @@ -105,10 +104,13 @@ def backward(tensors, grad_tensors=None, retain_graph=False): for each_tensor in grad_tensors: if each_tensor is not None: assert isinstance( - each_tensor, paddle.Tensor + each_tensor, (paddle.Tensor, core.eager.Tensor) ), "The argument 'grad_tensors' of paddle.autograd.backward is invalid, it can be 'None', 'paddle.Tensor' or 'list[None/paddle.Tensor]'." 
else: - grad_tensors = [None] * len(tensors) + if core._in_eager_mode(): + grad_tensors = [] + else: + grad_tensors = [None] * len(tensors) if len(grad_tensors) > 0: assert len(tensors) == len( @@ -116,5 +118,8 @@ def backward(tensors, grad_tensors=None, retain_graph=False): assert isinstance(retain_graph, bool), "retain_graph must be True or False" - core.dygraph_run_backward(tensors, grad_tensors, retain_graph, - framework._dygraph_tracer()) + if core._in_eager_mode(): + core.eager.run_backward(tensors, grad_tensors, retain_graph) + else: + core.dygraph_run_backward(tensors, grad_tensors, retain_graph, + framework._dygraph_tracer()) diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index 6843c0e4c3..2b67a20297 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -311,7 +311,7 @@ def monkey_patch_varbase(): """ if core._in_eager_mode(): - if not self.grad._is_initialized(): + if self.grad is None: return None # TODO(wanghuancoder) support SELECTED_ROWS return self.grad.numpy() diff --git a/python/paddle/fluid/tests/custom_op/custom_relu_op.cc b/python/paddle/fluid/tests/custom_op/custom_relu_op.cc index c89990be34..acaf7cb742 100644 --- a/python/paddle/fluid/tests/custom_op/custom_relu_op.cc +++ b/python/paddle/fluid/tests/custom_op/custom_relu_op.cc @@ -153,6 +153,7 @@ PD_BUILD_GRAD_OP(custom_relu_no_x_in_backward) .SetInferShapeFn(PD_INFER_SHAPE(ReluBackwardWithoutXInferShape)); void relu_cpu_forward_out(const paddle::Tensor& x, paddle::Tensor* out) { + out->reshape(x.shape()); PD_DISPATCH_FLOATING_TYPES( x.type(), "relu_cpu_forward", ([&] { relu_cpu_forward_kernel( @@ -164,6 +165,7 @@ void relu_cpu_backward_out(const paddle::Tensor& x, const paddle::Tensor& out, const paddle::Tensor& grad_out, paddle::Tensor* grad_x) { + grad_x->reshape(x.shape()); PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] { relu_cpu_backward_kernel( grad_out.data(), diff --git a/python/paddle/fluid/tests/custom_op/custom_relu_op.cu b/python/paddle/fluid/tests/custom_op/custom_relu_op.cu index 33c5ede299..4bb773cdae 100644 --- a/python/paddle/fluid/tests/custom_op/custom_relu_op.cu +++ b/python/paddle/fluid/tests/custom_op/custom_relu_op.cu @@ -94,6 +94,7 @@ void relu_cuda_forward_out(const paddle::Tensor& x, paddle::Tensor* out) { int numel = x.size(); int block = 512; int grid = (numel + block - 1) / block; + out->reshape(x.shape()); PD_DISPATCH_FLOATING_AND_HALF_TYPES( x.type(), "relu_cuda_forward_kernel", ([&] { relu_cuda_forward_kernel<<>>( @@ -108,6 +109,7 @@ void relu_cuda_backward_out(const paddle::Tensor& x, int numel = out.size(); int block = 512; int grid = (numel + block - 1) / block; + grad_x->reshape(x.shape()); PD_DISPATCH_FLOATING_AND_HALF_TYPES( out.type(), "relu_cuda_backward_kernel", ([&] { relu_cuda_backward_kernel<<>>( diff --git a/python/paddle/fluid/tests/custom_op/test_custom_attrs_jit.py b/python/paddle/fluid/tests/custom_op/test_custom_attrs_jit.py index 1c9c6eedba..785bfc7422 100644 --- a/python/paddle/fluid/tests/custom_op/test_custom_attrs_jit.py +++ b/python/paddle/fluid/tests/custom_op/test_custom_attrs_jit.py @@ -20,6 +20,7 @@ import paddle from paddle.utils.cpp_extension import load, get_build_directory from utils import paddle_includes, extra_cc_args, extra_nvcc_args from paddle.utils.cpp_extension.extension_utils import run_cmd +from paddle.fluid.framework import _test_eager_guard, _in_eager_mode # Because Windows 
don't use docker, the shared lib already exists in the # cache dir, it will not be compiled again unless the shared lib is removed. @@ -53,7 +54,7 @@ class TestJitCustomAttrs(unittest.TestCase): self.int64_vec_attr = [10000000000, 10000000000, 10000000000] self.str_vec_attr = ["StrAttr", "StrAttr", "StrAttr"] - def test_attr_value(self): + def func_attr_value(self): x = paddle.ones([2, 2], dtype='float32') x.stop_gradient = False out = custom_attrs.attr_test( @@ -65,7 +66,12 @@ class TestJitCustomAttrs(unittest.TestCase): self.assertTrue(np.array_equal(x.numpy(), out.numpy())) - def test_const_attr_value(self): + def test_attr_value(self): + with _test_eager_guard(): + self.func_attr_value() + self.func_attr_value() + + def func_const_attr_value(self): x = paddle.ones([2, 2], dtype='float32') x.stop_gradient = False out = custom_attrs.const_attr_test( @@ -77,6 +83,11 @@ class TestJitCustomAttrs(unittest.TestCase): self.assertTrue(np.array_equal(x.numpy(), out.numpy())) + def test_const_attr_value(self): + with _test_eager_guard(): + self.func_const_attr_value() + self.func_const_attr_value() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/custom_op/test_custom_concat.py b/python/paddle/fluid/tests/custom_op/test_custom_concat.py index 9049b604c9..62e61c5bc7 100644 --- a/python/paddle/fluid/tests/custom_op/test_custom_concat.py +++ b/python/paddle/fluid/tests/custom_op/test_custom_concat.py @@ -21,6 +21,7 @@ import paddle.static as static from paddle.utils.cpp_extension import load, get_build_directory from paddle.utils.cpp_extension.extension_utils import run_cmd from utils import paddle_includes, extra_cc_args, extra_nvcc_args +from paddle.fluid.framework import _test_eager_guard, _in_eager_mode # Because Windows don't use docker, the shared lib already exists in the # cache dir, it will not be compiled again unless the shared lib is removed. 
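The test updates above all follow the same pattern introduced by this patch: each former test_* case is split into a func_* body plus a thin test_* wrapper that runs the body once inside _test_eager_guard() (eager mode) and once more on the legacy dygraph path. A minimal sketch of that wrapper, with a hypothetical func_case body standing in for the real custom-op checks:

import unittest

import paddle
from paddle.fluid.framework import _test_eager_guard


class ExampleEagerGuardCase(unittest.TestCase):
    def func_case(self):
        # hypothetical body; the real tests exercise a custom operator here
        x = paddle.ones([2, 2], dtype='float32')
        self.assertEqual(x.shape, [2, 2])

    def test_case(self):
        # run once under eager mode, then once under legacy dygraph
        with _test_eager_guard():
            self.func_case()
        self.func_case()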
@@ -116,7 +117,7 @@ class TestCustomConcatDynamicAxisJit(unittest.TestCase): "custom op {}: {},\n paddle api {}: {}".format(name, out, name, pd_out)) - def test_dynamic(self): + def func_dynamic(self): for dtype in self.dtypes: for axis in self.axises: out, grad_inputs = concat_dynamic(custom_ops.custom_concat, @@ -128,6 +129,11 @@ class TestCustomConcatDynamicAxisJit(unittest.TestCase): for x_grad, pd_x_grad in zip(grad_inputs, pd_grad_inputs): self.check_output(x_grad, pd_x_grad, "x_grad") + def test_dynamic(self): + with _test_eager_guard(): + self.func_dynamic() + self.func_dynamic() + def test_static(self): for dtype in self.dtypes: for axis in self.axises: @@ -140,7 +146,7 @@ class TestCustomConcatDynamicAxisJit(unittest.TestCase): self.check_output(x1_grad, pd_x1_grad, "x1_grad") self.check_output(x2_grad, pd_x2_grad, "x2_grad") - def test_dynamic_with_attr(self): + def func_dynamic_with_attr(self): for dtype in self.dtypes: for axis in self.axises: out, grad_inputs = concat_dynamic( @@ -153,6 +159,11 @@ class TestCustomConcatDynamicAxisJit(unittest.TestCase): for x_grad, pd_x_grad in zip(grad_inputs, pd_grad_inputs): self.check_output(x_grad, pd_x_grad, "x_grad") + def test_dynamic_with_attr(self): + with _test_eager_guard(): + self.func_dynamic_with_attr() + self.func_dynamic_with_attr() + def test_static_with_attr(self): for dtype in self.dtypes: for axis in self.axises: diff --git a/python/paddle/fluid/tests/custom_op/test_custom_conj.py b/python/paddle/fluid/tests/custom_op/test_custom_conj.py index 25c88ee6c6..5f3c107a9b 100644 --- a/python/paddle/fluid/tests/custom_op/test_custom_conj.py +++ b/python/paddle/fluid/tests/custom_op/test_custom_conj.py @@ -21,6 +21,7 @@ import paddle.static as static from paddle.utils.cpp_extension import load, get_build_directory from paddle.utils.cpp_extension.extension_utils import run_cmd from utils import paddle_includes, extra_cc_args, extra_nvcc_args +from paddle.fluid.framework import _test_eager_guard # Because Windows don't use docker, the shared lib already exists in the # cache dir, it will not be compiled again unless the shared lib is removed. @@ -116,11 +117,16 @@ class TestCustomConjJit(unittest.TestCase): self.check_output(out, pd_out, "out") self.check_output(x_grad, pd_x_grad, "x's grad") - def test_dynamic(self): + def func_dynamic(self): for dtype in self.dtypes: np_input = np.random.random(self.shape).astype(dtype) self.run_dynamic(dtype, np_input) + def test_dynamic(self): + with _test_eager_guard(): + self.func_dynamic() + self.func_dynamic() + def test_static(self): for dtype in self.dtypes: np_input = np.random.random(self.shape).astype(dtype) diff --git a/python/paddle/fluid/tests/custom_op/test_custom_linear.py b/python/paddle/fluid/tests/custom_op/test_custom_linear.py index 0ba70eaa3e..811eedf1ed 100644 --- a/python/paddle/fluid/tests/custom_op/test_custom_linear.py +++ b/python/paddle/fluid/tests/custom_op/test_custom_linear.py @@ -22,6 +22,7 @@ import paddle.nn.functional as F from paddle.utils.cpp_extension import load, get_build_directory from paddle.utils.cpp_extension.extension_utils import run_cmd from utils import paddle_includes, extra_cc_args, extra_nvcc_args +from paddle.fluid.framework import _test_eager_guard, _in_eager_mode # Because Windows don't use docker, the shared lib already exists in the # cache dir, it will not be compiled again unless the shared lib is removed. 
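The custom-op tests touched in this patch first JIT-compile their operators with paddle.utils.cpp_extension.load and then call the generated Python bindings, which route through _run_custom_op when eager mode is active. A minimal loading sketch, assuming a hypothetical custom_relu_op.cc source file next to the test script:

from paddle.utils.cpp_extension import load

# JIT-compile the operator library; the returned module exposes one Python
# callable per registered op, e.g. custom_ops.custom_relu(x).
custom_ops = load(
    name='custom_jit_ops',
    sources=['custom_relu_op.cc'],
    verbose=True)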
@@ -94,7 +95,7 @@ class TestCustomLinearJit(unittest.TestCase): self.np_bias) self.check_output(pten_out, pd_out, "pten_out") - def test_dynamic(self): + def func_dynamic(self): for dtype in self.dtypes: pten_out = linear_dynamic(custom_ops.pten_linear, dtype, self.np_x, self.np_weight, self.np_bias) @@ -102,6 +103,11 @@ class TestCustomLinearJit(unittest.TestCase): self.np_bias) self.check_output(pten_out, pd_out, "pten_out") + def test_dynamic(self): + with _test_eager_guard(): + self.func_dynamic() + self.func_dynamic() + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/custom_op/test_custom_raw_op_kernel_op.py b/python/paddle/fluid/tests/custom_op/test_custom_raw_op_kernel_op.py index 207ea87974..4da99b1ea1 100644 --- a/python/paddle/fluid/tests/custom_op/test_custom_raw_op_kernel_op.py +++ b/python/paddle/fluid/tests/custom_op/test_custom_raw_op_kernel_op.py @@ -68,12 +68,6 @@ class TestCustomRawReluOp(unittest.TestCase): self.assertTrue(custom_raw_relu_op is not None) return custom_raw_relu_op(x) - def test_dygraph(self): - x = paddle.to_tensor(np.random.uniform(low=-1.0, high=1.0, size=[2, 3])) - y1 = self.custom_raw_relu(x) - y2 = paddle.nn.ReLU()(x) - self.assertTrue(np.array_equal(y1.numpy(), y2.numpy())) - def test_static(self): paddle.enable_static() shape = [2, 3] diff --git a/python/paddle/fluid/tests/custom_op/test_custom_relu_model.py b/python/paddle/fluid/tests/custom_op/test_custom_relu_model.py index dddb14eb78..81793f1391 100644 --- a/python/paddle/fluid/tests/custom_op/test_custom_relu_model.py +++ b/python/paddle/fluid/tests/custom_op/test_custom_relu_model.py @@ -22,6 +22,7 @@ from paddle.utils.cpp_extension import load, get_build_directory from paddle.utils.cpp_extension.extension_utils import run_cmd from utils import paddle_includes, extra_cc_args, extra_nvcc_args, IS_MAC +from paddle.fluid.framework import _test_eager_guard, _in_eager_mode # Because Windows don't use docker, the shared lib already exists in the # cache dir, it will not be compiled again unless the shared lib is removed. 
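The relu-based tests that follow compare the custom kernel against the built-in operator. A minimal sketch of that forward check, assuming custom_ops was produced by paddle.utils.cpp_extension.load as sketched above:

import numpy as np
import paddle
import paddle.nn.functional as F


def check_custom_relu(custom_ops):
    # the forward result of the JIT-compiled kernel should match the built-in op
    x = paddle.to_tensor(
        np.random.uniform(-1, 1, [4, 8]).astype('float32'),
        stop_gradient=False)
    custom_out = custom_ops.custom_relu(x)
    native_out = F.relu(x)
    np.testing.assert_array_equal(custom_out.numpy(), native_out.numpy())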
@@ -98,7 +99,7 @@ class TestDygraphModel(unittest.TestCase):
         self.x_spec = paddle.static.InputSpec(
             shape=[None, self.in_dim], dtype='float32', name='x')

-    def test_train_eval(self):
+    def func_train_eval(self):
         for device in self.devices:
             # set device
             paddle.set_device(device)
@@ -106,26 +107,34 @@ class TestDygraphModel(unittest.TestCase):
             # for train
             origin_relu_train_out = self.train_model(use_custom_op=False)
             custom_relu_train_out = self.train_model(use_custom_op=True)
-            custom_relu_dy2stat_train_out = self.train_model(
-                use_custom_op=True, dy2stat=True)  # for to_static
+            # open this when dy2stat is ready for eager
+            if not _in_eager_mode():
+                custom_relu_dy2stat_train_out = self.train_model(
+                    use_custom_op=True, dy2stat=True)  # for to_static
+                self.assertTrue(
+                    np.array_equal(origin_relu_train_out,
+                                   custom_relu_dy2stat_train_out))

             self.assertTrue(
                 np.array_equal(origin_relu_train_out, custom_relu_train_out))
-            self.assertTrue(
-                np.array_equal(origin_relu_train_out,
-                               custom_relu_dy2stat_train_out))

             # for eval
             origin_relu_eval_out = self.eval_model(use_custom_op=False)
             custom_relu_eval_out = self.eval_model(use_custom_op=True)
-            custom_relu_dy2stat_eval_out = self.eval_model(
-                use_custom_op=True, dy2stat=True)  # for to_static
+            if not _in_eager_mode():
+                custom_relu_dy2stat_eval_out = self.eval_model(
+                    use_custom_op=True, dy2stat=True)  # for to_static
+                self.assertTrue(
+                    np.array_equal(origin_relu_eval_out,
+                                   custom_relu_dy2stat_eval_out))

             self.assertTrue(
                 np.array_equal(origin_relu_eval_out, custom_relu_eval_out))
-            self.assertTrue(
-                np.array_equal(origin_relu_eval_out,
-                               custom_relu_dy2stat_eval_out))
+
+    def test_train_eval(self):
+        with _test_eager_guard():
+            self.func_train_eval()
+        self.func_train_eval()

     def train_model(self, use_custom_op=False, dy2stat=False):
         # reset random seed
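In TestDygraphModel.func_train_eval above, the dy2stat comparisons are additionally guarded with _in_eager_mode(), so they only run in legacy dygraph until to_static supports eager tensors. The guard in isolation looks roughly like this (check_dy2stat and run_model are hypothetical names used for illustration, not part of the patch):

import numpy as np
from paddle.fluid.framework import _in_eager_mode


def check_dy2stat(run_model, origin_out):
    # run_model stands in for TestDygraphModel.train_model / eval_model;
    # the to_static path is skipped while eager mode lacks dy2stat support.
    if not _in_eager_mode():
        dy2stat_out = run_model(use_custom_op=True, dy2stat=True)
        np.testing.assert_array_equal(origin_out, dy2stat_out)
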
diff --git a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py
index 407eb342ba..a747d10823 100644
--- a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py
+++ b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py
@@ -20,7 +20,7 @@ from paddle.utils.cpp_extension import load, get_build_directory
 from paddle.utils.cpp_extension.extension_utils import run_cmd
 from utils import paddle_includes, extra_cc_args, extra_nvcc_args, IS_WINDOWS, IS_MAC
 from test_custom_relu_op_setup import custom_relu_dynamic, custom_relu_static
-
+from paddle.fluid.framework import _test_eager_guard, _in_eager_mode
 # Because Windows don't use docker, the shared lib already exists in the
 # cache dir, it will not be compiled again unless the shared lib is removed.
 file = '{}\\custom_relu_module_jit\\custom_relu_module_jit.pyd'.format(
@@ -75,7 +75,7 @@ class TestJITLoad(unittest.TestCase):
                     "custom op out: {},\n paddle api out: {}".format(
                         out, pd_out))

-    def test_dynamic(self):
+    def func_dynamic(self):
         for device in self.devices:
             for dtype in self.dtypes:
                 if device == 'cpu' and dtype == 'float16':
@@ -95,8 +95,14 @@ class TestJITLoad(unittest.TestCase):
                     "custom op x grad: {},\n paddle api x grad: {}".format(
                         x_grad, pd_x_grad))

-    def test_exception(self):
+    def test_dynamic(self):
+        with _test_eager_guard():
+            self.func_dynamic()
+        self.func_dynamic()
+
+    def func_exception(self):
         caught_exception = False
+        # if not _in_eager_mode():
         try:
             x = np.random.uniform(-1, 1, [4, 8]).astype('int32')
             custom_relu_dynamic(custom_module.custom_relu, 'cpu', 'int32', x)
@@ -114,11 +120,11 @@ class TestJITLoad(unittest.TestCase):
                 "python/paddle/fluid/tests/custom_op/custom_relu_op.cc" in
                 str(e))
         self.assertTrue(caught_exception)
-        caught_exception = False

         # MAC-CI don't support GPU
         if IS_MAC:
             return
+        # if not _in_eager_mode():
         try:
             x = np.random.uniform(-1, 1, [4, 8]).astype('int32')
             custom_relu_dynamic(custom_module.custom_relu, 'gpu', 'int32', x)
@@ -132,6 +138,11 @@ class TestJITLoad(unittest.TestCase):
                 str(e))
         self.assertTrue(caught_exception)

+    def test_exception(self):
+        with _test_eager_guard():
+            self.func_exception()
+        self.func_exception()
+
     def test_load_multiple_module(self):
         custom_module = load(
             name='custom_conj_jit',
diff --git a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py
index 0af0aa1646..7c61e11a18 100644
--- a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py
+++ b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py
@@ -21,6 +21,7 @@ import paddle.static as static
 import subprocess
 import numpy as np
 from paddle.utils.cpp_extension.extension_utils import run_cmd
+from paddle.fluid.framework import _test_eager_guard


 def custom_relu_dynamic(func, device, dtype, np_x, use_func=True):
@@ -216,7 +217,7 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
                     "custom op out: {},\n paddle api out: {}".format(
                         out, pd_out))

-    def test_dynamic(self):
+    def func_dynamic(self):
         for device in self.devices:
             for dtype in self.dtypes:
                 if device == 'cpu' and dtype == 'float16':
@@ -236,6 +237,11 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
                     "custom op x grad: {},\n paddle api x grad: {}".format(
                         x_grad, pd_x_grad))

+    def test_dynamic(self):
+        with _test_eager_guard():
+            self.func_dynamic()
+        self.func_dynamic()
+
     def test_static_save_and_load_inference_model(self):
         paddle.enable_static()
         np_data = np.random.random((1, 1, 28, 28)).astype("float32")
diff --git a/python/paddle/fluid/tests/custom_op/test_custom_simple_slice.py b/python/paddle/fluid/tests/custom_op/test_custom_simple_slice.py
index c60bac4060..f68a37b1a2 100644
--- a/python/paddle/fluid/tests/custom_op/test_custom_simple_slice.py
+++ b/python/paddle/fluid/tests/custom_op/test_custom_simple_slice.py
@@ -20,6 +20,7 @@ import paddle
 from paddle.utils.cpp_extension import load, get_build_directory
 from paddle.utils.cpp_extension.extension_utils import run_cmd
 from utils import paddle_includes, extra_cc_args, extra_nvcc_args
+from paddle.fluid.framework import _test_eager_guard, _in_eager_mode

 # Because Windows don't use docker, the shared lib already exists in the
 # cache dir, it will not be compiled again unless the shared lib is removed.
@@ -39,7 +40,7 @@ custom_ops = load(


 class TestCustomSimpleSliceJit(unittest.TestCase):
-    def test_slice_output(self):
+    def func_slice_output(self):
         np_x = np.random.random((5, 2)).astype("float32")
         x = paddle.to_tensor(np_x)
         custom_op_out = custom_ops.custom_simple_slice(x, 2, 3)
@@ -48,6 +49,11 @@ class TestCustomSimpleSliceJit(unittest.TestCase):
             np.array_equal(custom_op_out, np_out),
             "custom op: {},\n numpy: {}".format(np_out, custom_op_out.numpy()))

+    def test_slice_output(self):
+        with _test_eager_guard():
+            self.func_slice_output()
+        self.func_slice_output()
+

 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/tests/custom_op/test_dispatch_jit.py b/python/paddle/fluid/tests/custom_op/test_dispatch_jit.py
index 12e9f50a5e..0d2cb941ea 100644
--- a/python/paddle/fluid/tests/custom_op/test_dispatch_jit.py
+++ b/python/paddle/fluid/tests/custom_op/test_dispatch_jit.py
@@ -19,7 +19,7 @@ import numpy as np
 from paddle.utils.cpp_extension import load, get_build_directory
 from utils import paddle_includes, extra_cc_args
 from paddle.utils.cpp_extension.extension_utils import run_cmd
-
+from paddle.fluid.framework import _test_eager_guard
 # Because Windows don't use docker, the shared lib already exists in the
 # cache dir, it will not be compiled again unless the shared lib is removed.
 file = '{}\\dispatch_op\\dispatch_op.pyd'.format(get_build_directory())
@@ -39,7 +39,7 @@ class TestJitDispatch(unittest.TestCase):
     def setUp(self):
         paddle.set_device('cpu')

-    def run_dispatch_test(self, func, dtype):
+    def run_dispatch_test_impl(self, func, dtype):
         np_x = np.ones([2, 2]).astype(dtype)
         x = paddle.to_tensor(np_x)
         out = func(x)
@@ -50,6 +50,11 @@ class TestJitDispatch(unittest.TestCase):
             np.array_equal(np_x, np_out),
             "custom op x: {},\n custom op out: {}".format(np_x, np_out))

+    def run_dispatch_test(self, func, dtype):
+        with _test_eager_guard():
+            self.run_dispatch_test_impl(func, dtype)
+        self.run_dispatch_test_impl(func, dtype)
+
     def test_dispatch_integer(self):
         dtypes = ["int32", "int64", "int8", "uint8", "int16"]
         for dtype in dtypes:
diff --git a/python/paddle/fluid/tests/custom_op/test_multi_out_jit.py b/python/paddle/fluid/tests/custom_op/test_multi_out_jit.py
index 97b37498c4..4fc9270b0f 100644
--- a/python/paddle/fluid/tests/custom_op/test_multi_out_jit.py
+++ b/python/paddle/fluid/tests/custom_op/test_multi_out_jit.py
@@ -22,7 +22,7 @@ from paddle.utils.cpp_extension import load
 from paddle.utils.cpp_extension import load, get_build_directory
 from paddle.utils.cpp_extension.extension_utils import run_cmd
 from utils import paddle_includes, extra_cc_args
-
+from paddle.fluid.framework import _test_eager_guard
 # Because Windows don't use docker, the shared lib already exists in the
 # cache dir, it will not be compiled again unless the shared lib is removed.
 file = '{}\\multi_out_jit\\multi_out_jit.pyd'.format(get_build_directory())
@@ -84,7 +84,7 @@ class TestMultiOutputDtypes(unittest.TestCase):
             self.check_multi_outputs(res)
         paddle.disable_static()

-    def test_dynamic(self):
+    def func_dynamic(self):
         for device in self.devices:
             for dtype in self.dtypes:
                 paddle.set_device(device)
@@ -95,6 +95,11 @@ class TestMultiOutputDtypes(unittest.TestCase):
                 self.assertTrue(len(outs) == 3)
                 self.check_multi_outputs(outs, True)

+    def test_dynamic(self):
+        with _test_eager_guard():
+            self.func_dynamic()
+        self.func_dynamic()
+

 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_custom_grad_input.py b/python/paddle/fluid/tests/unittests/test_custom_grad_input.py
index bc280a0189..83a25b7162 100644
--- a/python/paddle/fluid/tests/unittests/test_custom_grad_input.py
+++ b/python/paddle/fluid/tests/unittests/test_custom_grad_input.py
@@ -20,6 +20,7 @@ import numpy as np
 import paddle
 import paddle.fluid.dygraph as dg
 from op_test import OpTest
+from paddle.fluid.framework import _test_eager_guard


 class TestTensorBackward(unittest.TestCase):
@@ -29,7 +30,7 @@ class TestTensorBackward(unittest.TestCase):
         if paddle.is_compiled_with_cuda():
             self._places.append(paddle.CUDAPlace(0))

-    def test_tensor_backward(self):
+    def func_tensor_backward(self):
         for dtype in self._dtypes:
             x = np.random.random([2, 100]).astype(dtype)
             y = np.random.random([100, 2]).astype(dtype)
@@ -48,6 +49,11 @@ class TestTensorBackward(unittest.TestCase):

                     self.assertTrue(np.allclose(x_grad, x_tensor.grad.numpy()))

+    def test_tensor_backward(self):
+        with _test_eager_guard():
+            self.func_tensor_backward()
+        self.func_tensor_backward()
+

 class TestBackwardAPI(unittest.TestCase):
     def setUp(self):
@@ -56,7 +62,7 @@ class TestBackwardAPI(unittest.TestCase):
         if paddle.is_compiled_with_cuda():
             self._places.append(paddle.CUDAPlace(0))

-    def test_backward_api(self):
+    def func_backward_api(self):
         for dtype in self._dtypes:
             x = np.random.random([2, 2]).astype(dtype)
             y = np.random.random([2, 2]).astype(dtype)
@@ -78,7 +84,12 @@ class TestBackwardAPI(unittest.TestCase):
                     self.assertTrue(
                         np.allclose(x_grad * 2, x_tensor.grad.numpy()))

-    def test_backward_single_tensor(self):
+    def test_backward_api(self):
+        with _test_eager_guard():
+            self.func_backward_api()
+        self.func_backward_api()
+
+    def func_backward_single_tensor(self):
         for dtype in self._dtypes:
             x = np.random.random([2, 2]).astype(dtype)
             y = np.random.random([2, 2]).astype(dtype)
@@ -97,7 +108,12 @@ class TestBackwardAPI(unittest.TestCase):

                     self.assertTrue(np.allclose(x_grad, x_tensor.grad.numpy()))

-    def test_backward_none_grad_tensor(self):
+    def test_backward_single_tensor(self):
+        with _test_eager_guard():
+            self.func_backward_single_tensor()
+        self.func_backward_single_tensor()
+
+    def func_backward_none_grad_tensor(self):
         for dtype in self._dtypes:
             x = np.random.random([2, 2]).astype(dtype)
             y = np.random.random([2, 2]).astype(dtype)
@@ -115,7 +131,12 @@ class TestBackwardAPI(unittest.TestCase):

                     self.assertTrue(np.allclose(x_grad, x_tensor.grad.numpy()))

-    def test_backward_accumulator_with_init_grad(self):
+    def test_backward_none_grad_tensor(self):
+        with _test_eager_guard():
+            self.func_backward_none_grad_tensor()
+        self.func_backward_none_grad_tensor()
+
+    def func_backward_accumulator_with_init_grad(self):
         for dtype in self._dtypes:
             x = np.random.random([10, ]).astype(dtype)
             y_grad = np.random.random([10, ]).astype(dtype)
@@ -134,11 +155,14 @@ class TestBackwardAPI(unittest.TestCase):
                     y = x**2
                    z = x**3
-                    x_grad = 2 * x_tensor * (
-                        y_grad_tensor + 3 * y_tensor * y_tensor * z_grad_tensor)
+                    x_grad = 2 * x * (y_grad + 3 * y * y * z_grad)

-                    self.assertTrue(
-                        np.allclose(x_grad.numpy(), x_tensor.grad.numpy()))
+                    self.assertTrue(np.allclose(x_grad, x_tensor.grad.numpy()))
+
+    def test_backward_accumulator_with_init_grad(self):
+        with _test_eager_guard():
+            self.func_backward_accumulator_with_init_grad()
+        self.func_backward_accumulator_with_init_grad()


 if __name__ == '__main__':
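The TestBackwardAPI cases above exercise paddle.autograd.backward in both modes. A minimal standalone sketch of the call pattern they verify, with an illustrative shape and a ones gradient rather than the tests' exact data:

import numpy as np
import paddle

x = paddle.to_tensor(np.random.random([2, 2]).astype("float32"),
                     stop_gradient=False)
y = x ** 2
# supply an explicit initial gradient for y, as the tests do
paddle.autograd.backward([y], grad_tensors=[paddle.ones_like(y)])
# d(x**2)/dx == 2 * x when the incoming gradient is all ones
assert np.allclose(x.grad.numpy(), (2 * x).numpy())
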
diff --git a/python/paddle/fluid/tests/unittests/test_egr_python_api.py b/python/paddle/fluid/tests/unittests/test_egr_python_api.py
index 9744cda629..27aec284de 100644
--- a/python/paddle/fluid/tests/unittests/test_egr_python_api.py
+++ b/python/paddle/fluid/tests/unittests/test_egr_python_api.py
@@ -50,7 +50,7 @@ class EagerScaleTestCase(unittest.TestCase):
         data_eager.retain_grads()

         out_eager = core.eager.scale(data_eager, 1.0, 0.9, True, True)
-        self.assertFalse(data_eager.grad._is_initialized())
+        self.assertIsNone(data_eager.grad)
         out_eager.backward(grad_eager, False)
         self.assertTrue(data_eager.grad._is_initialized())
         self.assertTrue(np.array_equal(data_eager.grad.numpy(), input_data))
@@ -72,7 +72,7 @@ class EagerScaleTestCase(unittest.TestCase):
         data_eager.retain_grads()

         out_eager = core.eager.scale(data_eager, 1.0, 0.9, True, True)
-        self.assertFalse(data_eager.grad._is_initialized())
+        self.assertIsNone(data_eager.grad)
         with self.assertRaisesRegexp(
                 AssertionError,
                 "The type of grad_tensor must be paddle.Tensor"):
diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml
index 6c27d465cb..aac68efc59 100644
--- a/python/paddle/utils/code_gen/api.yaml
+++ b/python/paddle/utils/code_gen/api.yaml
@@ -158,6 +158,7 @@
   param : [x]
   kernel :
     func : scale, scale_sr
+  inplace : (x -> out)

 - api : sign
   args : (Tensor x)
diff --git a/python/paddle/utils/cpp_extension/extension_utils.py b/python/paddle/utils/cpp_extension/extension_utils.py
index 853a98a62b..b0a5d37a53 100644
--- a/python/paddle/utils/cpp_extension/extension_utils.py
+++ b/python/paddle/utils/cpp_extension/extension_utils.py
@@ -146,6 +146,9 @@ def custom_write_stub(resource, pyfile):
     import types
     import paddle

+    cur_dir = os.path.dirname(os.path.abspath(__file__))
+    so_path = os.path.join(cur_dir, "{resource}")
+
     def inject_ext_module(module_name, api_names):
         if module_name in sys.modules:
             return sys.modules[module_name]
@@ -157,9 +160,6 @@ def custom_write_stub(resource, pyfile):
         return new_module

     def __bootstrap__():
-        cur_dir = os.path.dirname(os.path.abspath(__file__))
-        so_path = os.path.join(cur_dir, "{resource}")
-
         assert os.path.exists(so_path)

         # load custom op shared library with abs path
@@ -169,6 +169,7 @@ def custom_write_stub(resource, pyfile):
     __bootstrap__()

     {custom_api}
+
     """).lstrip()


 # Parse registerring op information
@@ -900,7 +901,7 @@ def _generate_python_module(module_name,
     # delete the temp file before exit python process
     atexit.register(lambda: remove_if_exit(api_file))

-    # write into .py file with RWLock
+    # write into .py file with RWLockc
     api_content = [_custom_api_content(op_name) for op_name in op_names]
     with open(api_file, 'w') as f:
         f.write('\n\n'.join(api_content))
@@ -911,13 +912,15 @@ def _generate_python_module(module_name,


 def _custom_api_content(op_name):
-    params_str, ins_str, attrs_str, outs_str = _get_api_inputs_str(op_name)
-
+    params_str, ins_str, attrs_str, outs_str, in_names, attrs_names = _get_api_inputs_str(
+        op_name)
+    lower_in_names = [p.split("@")[0].lower() for p in in_names]
     API_TEMPLATE = textwrap.dedent("""
-        from paddle.fluid.core import VarBase
-        from paddle.fluid.framework import in_dygraph_mode, _dygraph_tracer
+        import paddle.fluid.core as core
+        from paddle.fluid.core import VarBase, CustomOpKernelContext
+        from paddle.fluid.framework import in_dygraph_mode, _dygraph_tracer, _in_eager_mode
         from paddle.fluid.layer_helper import LayerHelper
-
+
         def {op_name}({inputs}):
             # prepare inputs and outputs
             ins = {ins}
@@ -928,9 +931,20 @@ def _custom_api_content(op_name):
             # The output variable's dtype use default value 'float32',
             # and the actual dtype of output variable will be inferred in runtime.
             if in_dygraph_mode():
-                for out_name in out_names:
-                    outs[out_name] = VarBase()
-                _dygraph_tracer().trace_op(type="{op_name}", inputs=ins, outputs=outs, attrs=attrs)
+                if _in_eager_mode():
+                    ctx = CustomOpKernelContext()
+                    for i in {in_names}:
+                        ctx.add_inputs(i)
+                    for j in {attr_names}:
+                        ctx.add_attr(j)
+                    for out_name in out_names:
+                        outs[out_name] = core.eager.Tensor()
+                        ctx.add_outputs(outs[out_name])
+                    core.eager._run_custom_op(ctx, "{op_name}", True)
+                else:
+                    for out_name in out_names:
+                        outs[out_name] = VarBase()
+                    _dygraph_tracer().trace_op(type="{op_name}", inputs=ins, outputs=outs, attrs=attrs)
             else:
                 helper = LayerHelper("{op_name}", **locals())
                 for out_name in out_names:
@@ -949,6 +963,9 @@ def _custom_api_content(op_name):
         inputs=params_str,
         ins=ins_str,
         attrs=attrs_str,
+        # "[x, y, z]""
+        in_names="[" + ",".join(lower_in_names) + "]",
+        attr_names="[" + ",".join(attrs_names) + "]",
         out_names=outs_str)

     return api_content
@@ -996,7 +1013,7 @@ def _get_api_inputs_str(op_name):
     ])
     # e.g: ['Out', 'Index']
     outs_str = "[%s]" % ','.join(["'{}'".format(name) for name in out_names])
-    return params_str, ins_str, attrs_str, outs_str
+    return params_str, ins_str, attrs_str, outs_str, in_names, attr_names


 def _write_setup_file(name,
--
GitLab
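For reference, the dygraph branch of the wrapper generated by the API_TEMPLATE above expands, for a hypothetical custom op named custom_relu with a single input 'X', no attributes and a single output 'Out', to roughly the following sketch (the static-graph branch and the exact generated text are omitted; names here are illustrative, not the literal generator output):

import paddle.fluid.core as core
from paddle.fluid.core import VarBase, CustomOpKernelContext
from paddle.fluid.framework import in_dygraph_mode, _dygraph_tracer, _in_eager_mode


def custom_relu(x):
    # prepare inputs and outputs
    ins = {'X': x}
    attrs = {}
    outs = {}
    out_names = ['Out']

    if in_dygraph_mode():
        if _in_eager_mode():
            # eager mode: pack inputs/attrs/outputs into a kernel context
            # and run the registered custom kernel directly
            ctx = CustomOpKernelContext()
            for i in [x]:
                ctx.add_inputs(i)
            # attributes, if any, would be added here via ctx.add_attr(...)
            for out_name in out_names:
                outs[out_name] = core.eager.Tensor()
                ctx.add_outputs(outs[out_name])
            core.eager._run_custom_op(ctx, "custom_relu", True)
        else:
            # legacy dygraph: let the tracer schedule the operator
            for out_name in out_names:
                outs[out_name] = VarBase()
            _dygraph_tracer().trace_op(
                type="custom_relu", inputs=ins, outputs=outs, attrs=attrs)
    else:
        raise NotImplementedError("static-graph branch omitted in this sketch")

    res = [outs[out_name] for out_name in out_names]
    return res[0] if len(res) == 1 else res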