From 888a30c95a7be2e591b6f65043ec21ca25be8493 Mon Sep 17 00:00:00 2001 From: lzydev <1528794076@qq.com> Date: Mon, 27 Mar 2023 15:13:06 +0800 Subject: [PATCH] Automatically generate 'assign' operator (#51940) * support assign op * support assign infer_var_type * change code according to review * change code according to review * only save 'get_infer_var_type_func' * rest file mode --- paddle/fluid/operators/CMakeLists.txt | 4 +- paddle/fluid/operators/assign_op.cc | 152 ------------------ paddle/fluid/operators/generator/filters.py | 14 ++ .../fluid/operators/generator/generate_op.py | 2 + .../operators/generator/generate_sparse_op.py | 2 + .../generator/get_expected_kernel_func.cc | 18 +++ .../generator/get_expected_kernel_func.h | 4 + .../generator/templates/operator_utils.c.j2 | 9 ++ paddle/fluid/operators/unity_build_rule.cmake | 1 - paddle/phi/api/yaml/op_compat.yaml | 10 ++ paddle/phi/api/yaml/static_backward.yaml | 7 + paddle/phi/api/yaml/static_ops.yaml | 11 ++ paddle/phi/kernels/assign_kernel.cc | 12 +- .../kernels/selected_rows/assign_kernel.cc | 12 +- 14 files changed, 97 insertions(+), 161 deletions(-) delete mode 100644 paddle/fluid/operators/assign_op.cc diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 39d3515899c..c56bb972b24 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -169,7 +169,7 @@ if (WITH_ASCEND) endif() if (WITH_ASCEND_CL) - cc_test(assign_op_npu_test SRCS assign_op_npu_test.cc DEPS assign_op) + cc_test(assign_op_npu_test SRCS assign_op_npu_test.cc DEPS generated_static_op) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} npu_op_runner) endif() @@ -185,7 +185,7 @@ set(GLOB_OPERATOR_DEPS ${OPERATOR_DEPS} CACHE INTERNAL "Global Op dependencies") cc_test(test_common_infer_shape_functions SRCS test_common_infer_shape_functions.cc DEPS common_infer_shape_functions ${COMMON_OP_DEPS} activation_op elementwise_add_op softmax_op softmax) cc_test(gather_test SRCS gather_test.cc DEPS tensor) -cc_test(assign_op_test SRCS assign_op_test.cc DEPS assign_op) +cc_test(assign_op_test SRCS assign_op_test.cc DEPS generated_static_op) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor math_function) cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor) cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) diff --git a/paddle/fluid/operators/assign_op.cc b/paddle/fluid/operators/assign_op.cc deleted file mode 100644 index ebd4fbf491e..00000000000 --- a/paddle/fluid/operators/assign_op.cc +++ /dev/null @@ -1,152 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/
-
-#include "paddle/fluid/operators/assign_op.h"
-
-#include <string>
-
-#include "paddle/fluid/framework/infershape_utils.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
-#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
-#include "paddle/fluid/prim/utils/static/desc_tensor.h"
-#include "paddle/phi/core/infermeta_utils.h"
-#include "paddle/phi/infermeta/unary.h"
-namespace paddle {
-namespace framework {
-class OpDesc;
-class Variable;
-}  // namespace framework
-namespace imperative {
-class OpBase;
-}  // namespace imperative
-}  // namespace paddle
-
-namespace paddle {
-namespace operators {
-
-class AssignOp : public framework::OperatorWithKernel {
- public:
-  AssignOp(const std::string &type,
-           const framework::VariableNameMap &inputs,
-           const framework::VariableNameMap &outputs,
-           const framework::AttributeMap &attrs)
-      : OperatorWithKernel(type, inputs, outputs, attrs) {}
-
- protected:
-  phi::KernelKey GetKernelTypeForVar(
-      const std::string &var_name,
-      const phi::DenseTensor &tensor,
-      const phi::KernelKey &expected_kernel_type) const override {
-    return phi::KernelKey(phi::Backend::ALL_BACKEND,
-                          tensor.layout(),
-                          expected_kernel_type.dtype());
-  }
-
-  phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const override {
-    const framework::Variable *var = ctx.InputVar("X");
-    if (var->IsType<framework::LoDTensorArray>()) {
-      auto t_arr = var->Get<framework::LoDTensorArray>();
-      // NOTE(liym27): Support an empty tensor array as Input.
-      // And set the kernel type is float.
-      if (t_arr.size() == 0) {
-        return phi::KernelKey(framework::proto::VarType::FP32,
-                              ctx.device_context().GetPlace());
-      }
-    }
-
-    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-                          ctx.device_context().GetPlace());
-  }
-};
-
-class AssignInferVarType : public framework::VarTypeInference {
- public:
-  void operator()(framework::InferVarTypeContext *ctx) const override {
-    ctx->SyncTypeAndDataType("X", "Out");
-  }
-};
-
-class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput(
-        "X",
-        "(phi::DenseTensor, SelectedRows or phi::DenseTensorArray) The input "
-        "variable "
-        "could be phi::DenseTensor, SelectedRows or phi::DenseTensorArray.")
-        .AsDispensable();
-    AddOutput("Out",
-              "(phi::DenseTensor, SelectedRows or phi::DenseTensorArray) The "
-              "type of output "
-              "is the same as input X.");
-    AddComment(R"DOC(Assign Operator
-
-Out = X, when type in [phi::DenseTensor/SelectedRows/phi::DenseTensorArray]
-raise error if the type is not listed above.
-)DOC"); - } -}; - -template -class AssignGradMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr op) const override { - op->SetType("assign"); - op->SetInput("X", this->OutputGrad("Out")); - op->SetOutput("Out", this->InputGrad("X")); - } -}; - -class AssignCompositeGradOpMaker : public prim::CompositeGradOpMakerBase { - using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase; - - public: - void Apply() override { - paddle::Tensor out_grad = this->GetSingleOutputGrad("Out"); - paddle::Tensor input_grad = this->GetSingleInputGrad("X"); - - auto dx_ptr = this->GetOutputPtr(&input_grad); - std::string dx_name = this->GetOutputName(input_grad); - - VLOG(6) << "Running assign_grad composite func"; - prim::assign_grad(out_grad, dx_ptr); - this->RecoverOutputName(input_grad, dx_name); - } -}; - -DECLARE_INPLACE_OP_INFERER(AssignOpInplaceInferer, {"X", "Out"}); - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -DECLARE_INFER_SHAPE_FUNCTOR(assign, - AssignInferShapeFunctor, - PD_INFER_META(phi::UnchangedInferMeta)); -REGISTER_OPERATOR(assign, - ops::AssignOp, - ops::AssignCompositeGradOpMaker, - ops::AssignGradMaker, - ops::AssignGradMaker, - ops::AssignOpProtoMaker, - ops::AssignOpInplaceInferer, - ops::AssignInferVarType, - AssignInferShapeFunctor); diff --git a/paddle/fluid/operators/generator/filters.py b/paddle/fluid/operators/generator/filters.py index 6520a2e0751..56f87f3c1de 100644 --- a/paddle/fluid/operators/generator/filters.py +++ b/paddle/fluid/operators/generator/filters.py @@ -30,6 +30,20 @@ from type_mapping import ( ) +def get_infer_var_type_func(op_name): + if op_name == "assign": + return f""" + class {to_pascal_case(op_name)}InferVarType : public framework::VarTypeInference {{ + public: + void operator()(framework::InferVarTypeContext *ctx) const override {{ + ctx->SyncTypeAndDataType("X", "Out"); + }} +}}; +""" + else: + return None + + def quote(s): return '"{}"'.format(s) diff --git a/paddle/fluid/operators/generator/generate_op.py b/paddle/fluid/operators/generator/generate_op.py index 223e6b714ae..ac788f5f931 100644 --- a/paddle/fluid/operators/generator/generate_op.py +++ b/paddle/fluid/operators/generator/generate_op.py @@ -23,6 +23,7 @@ from filters import ( cartesian_prod_mapping, delete_last_underline, find_optinal_inputs_name, + get_infer_var_type_func, to_composite_grad_opmaker_name, to_input_name, to_int_array_tensor_name, @@ -66,6 +67,7 @@ env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping env.filters["to_composite_grad_opmaker_name"] = to_composite_grad_opmaker_name env.filters["to_variable_names"] = to_variable_names +env.filters["get_infer_var_type_func"] = get_infer_var_type_func env.filters["assert_dense_or_sr"] = assert_dense_or_sr env.filters["find_optinal_inputs_name"] = find_optinal_inputs_name env.tests["base_op"] = is_base_op diff --git a/paddle/fluid/operators/generator/generate_sparse_op.py b/paddle/fluid/operators/generator/generate_sparse_op.py index cee478e0df6..2635f0f67a1 100644 --- a/paddle/fluid/operators/generator/generate_sparse_op.py +++ b/paddle/fluid/operators/generator/generate_sparse_op.py @@ -21,6 +21,7 @@ from filters import ( assert_dense_or_sr, cartesian_prod_mapping, find_optinal_inputs_name, + get_infer_var_type_func, to_composite_grad_opmaker_name, to_input_name, 
     to_int_array_tensor_name,
@@ -67,6 +68,7 @@ env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
 env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
 env.filters["to_composite_grad_opmaker_name"] = to_composite_grad_opmaker_name
 env.filters["to_variable_names"] = to_variable_names
+env.filters["get_infer_var_type_func"] = get_infer_var_type_func
 env.tests["base_op"] = is_base_op
 env.tests["composite_op"] = is_composite_op
 env.tests["vec"] = is_vec
diff --git a/paddle/fluid/operators/generator/get_expected_kernel_func.cc b/paddle/fluid/operators/generator/get_expected_kernel_func.cc
index db963aa27ae..79a17940280 100644
--- a/paddle/fluid/operators/generator/get_expected_kernel_func.cc
+++ b/paddle/fluid/operators/generator/get_expected_kernel_func.cc
@@ -101,6 +101,24 @@ phi::KernelKey GetReduceGradExpectedKernelType(
   return phi::KernelKey(input_data_type, ctx.GetPlace());
 }
 
+phi::KernelKey GetAssignExpectedKernelType(
+    const framework::ExecutionContext& ctx,
+    const framework::OperatorWithKernel* op_ptr) {
+  const framework::Variable* var = ctx.InputVar("X");
+  if (var->IsType<framework::LoDTensorArray>()) {
+    auto t_arr = var->Get<framework::LoDTensorArray>();
+    // NOTE(liym27): Support an empty tensor array as Input.
+    // And set the kernel type is float.
+    if (t_arr.size() == 0) {
+      return phi::KernelKey(framework::proto::VarType::FP32,
+                            ctx.device_context().GetPlace());
+    }
+  }
+  return phi::KernelKey(
+      op_ptr->OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+      ctx.device_context().GetPlace());
+}
+
 phi::KernelKey GetSgdExpectedKernelType(
     const framework::ExecutionContext& ctx,
     const framework::OperatorWithKernel* op_ptr) {
diff --git a/paddle/fluid/operators/generator/get_expected_kernel_func.h b/paddle/fluid/operators/generator/get_expected_kernel_func.h
index a5883f44d7c..f360c0d4b08 100644
--- a/paddle/fluid/operators/generator/get_expected_kernel_func.h
+++ b/paddle/fluid/operators/generator/get_expected_kernel_func.h
@@ -28,6 +28,10 @@ phi::KernelKey GetReduceGradExpectedKernelType(
     const framework::ExecutionContext& ctx,
     const framework::OperatorWithKernel* op_ptr);
 
+phi::KernelKey GetAssignExpectedKernelType(
+    const framework::ExecutionContext& ctx,
+    const framework::OperatorWithKernel* op_ptr);
+
 phi::KernelKey GetSgdExpectedKernelType(
     const framework::ExecutionContext& ctx,
     const framework::OperatorWithKernel* op_ptr);
diff --git a/paddle/fluid/operators/generator/templates/operator_utils.c.j2 b/paddle/fluid/operators/generator/templates/operator_utils.c.j2
index e252e42ca7b..83939edd007 100644
--- a/paddle/fluid/operators/generator/templates/operator_utils.c.j2
+++ b/paddle/fluid/operators/generator/templates/operator_utils.c.j2
@@ -438,6 +438,11 @@ class {{op["op_name"] | to_pascal_case}}Op : public framework::OperatorWithKerne
 {%- endif %}
 };
 
+{% set infer_var_type_func_str = op["op_name"] | get_infer_var_type_func %}
+{% if infer_var_type_func_str is not none %}
+{{infer_var_type_func_str}}
+{% endif %}
+
 DECLARE_INFER_SHAPE_FUNCTOR({{op["op_name"]}}, {{op["op_name"] | to_pascal_case}}InferShapeFunctor,
                             PD_INFER_META(phi::{{op["infer_meta"]["func"]}}));
 {# inplace inferer #}
@@ -475,6 +480,10 @@ REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op,
 {% if op is supports_inplace %}{# inplace#}
   ops::{{name | to_pascal_case}}InplaceInferer,
 {% endif %}
+{% set infer_var_type_func_str = op["op_name"] | get_infer_var_type_func %}
+{% if infer_var_type_func_str is not none %}
+  ops::{{name | to_pascal_case}}InferVarType,
+{% endif %}
 {% if "backward_composite" in op and op["backward_composite"] is 
not none %} ops::{{op["backward_composite"] | to_composite_grad_opmaker_name}}, {% endif %} diff --git a/paddle/fluid/operators/unity_build_rule.cmake b/paddle/fluid/operators/unity_build_rule.cmake index f1d629b5c86..d713696d65b 100644 --- a/paddle/fluid/operators/unity_build_rule.cmake +++ b/paddle/fluid/operators/unity_build_rule.cmake @@ -15,7 +15,6 @@ register_unity_group( argsort_op.cc array_to_lod_tensor_op.cc assert_op.cc - assign_op.cc assign_value_op.cc attention_lstm_op.cc average_accumulates_op.cc diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 302d0ce4025..39491979018 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -131,6 +131,16 @@ extra : attrs : [bool use_mkldnn = false, bool use_cudnn = false] +- op : assign + backward : assign_grad + inputs : + x : X + outputs : + out : Out + manual_signature : [assign, assign_grad] + get_expected_kernel_type : + assign : GetAssignExpectedKernelType + - op : atan inputs : x : X diff --git a/paddle/phi/api/yaml/static_backward.yaml b/paddle/phi/api/yaml/static_backward.yaml index 3599b7064ca..825439d931e 100644 --- a/paddle/phi/api/yaml/static_backward.yaml +++ b/paddle/phi/api/yaml/static_backward.yaml @@ -1,5 +1,12 @@ # This file is to support those static ops different the dynamic. +- backward_op : assign_grad + forward : assign (Tensor x) -> Tensor(out) + args : (Tensor out_grad) + output : Tensor(x_grad) + composite: assign_grad(out_grad, x_grad) + invoke : assign(out_grad) + - backward_op : frobenius_norm_grad forward: frobenius_norm (Tensor x, IntArray axis={0}, bool keepdim=false, bool reduce_all=false, int in_dtype=-1, int out_dtype=-1) -> Tensor(out) args : (Tensor x, Tensor out, Tensor out_grad, IntArray axis={0}, bool keepdim=false, bool reduce_all=false, int in_dtype=-1, int out_dtype=-1) diff --git a/paddle/phi/api/yaml/static_ops.yaml b/paddle/phi/api/yaml/static_ops.yaml index f04a2961fbf..de094829c41 100644 --- a/paddle/phi/api/yaml/static_ops.yaml +++ b/paddle/phi/api/yaml/static_ops.yaml @@ -26,6 +26,17 @@ func : all_reduce param: [x, reduce_type] +- op : assign + args : (Tensor x) + output : Tensor + infer_meta : + func : UnchangedInferMeta + kernel : + func : assign + optional : x + inplace : (x -> out) + backward : assign_grad + - op : broadcast args : (Tensor x, int ring_id = 0, int root = 0) output : Tensor(out) diff --git a/paddle/phi/kernels/assign_kernel.cc b/paddle/phi/kernels/assign_kernel.cc index 84f03869a49..09046ef4556 100644 --- a/paddle/phi/kernels/assign_kernel.cc +++ b/paddle/phi/kernels/assign_kernel.cc @@ -122,7 +122,9 @@ PD_REGISTER_GENERAL_KERNEL(assign_array, CPU, ALL_LAYOUT, phi::AssignArrayKernel, - ALL_DTYPE) {} + ALL_DTYPE) { + kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(assign_value, CPU, ALL_LAYOUT, @@ -146,7 +148,9 @@ PD_REGISTER_GENERAL_KERNEL(assign_array, GPU, ALL_LAYOUT, phi::AssignArrayKernel, - ALL_DTYPE) {} + ALL_DTYPE) { + kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(assign_value, GPU, ALL_LAYOUT, @@ -171,7 +175,9 @@ PD_REGISTER_GENERAL_KERNEL(assign_array, XPU, ALL_LAYOUT, phi::AssignArrayKernel, - ALL_DTYPE) {} + ALL_DTYPE) { + kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); +} PD_REGISTER_KERNEL(assign_value, XPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/selected_rows/assign_kernel.cc b/paddle/phi/kernels/selected_rows/assign_kernel.cc index 993c5f81d34..32045d3106f 100644 --- 
a/paddle/phi/kernels/selected_rows/assign_kernel.cc +++ b/paddle/phi/kernels/selected_rows/assign_kernel.cc @@ -38,14 +38,18 @@ PD_REGISTER_GENERAL_KERNEL(assign_sr, CPU, ALL_LAYOUT, phi::sr::AssignKernel, - ALL_DTYPE) {} + ALL_DTYPE) { + kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); +} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) PD_REGISTER_GENERAL_KERNEL(assign_sr, GPU, ALL_LAYOUT, phi::sr::AssignKernel, - ALL_DTYPE) {} + ALL_DTYPE) { + kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); +} #endif #ifdef PADDLE_WITH_XPU @@ -53,5 +57,7 @@ PD_REGISTER_GENERAL_KERNEL(assign_sr, XPU, ALL_LAYOUT, phi::sr::AssignKernel, - ALL_DTYPE) {} + ALL_DTYPE) { + kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); +} #endif -- GitLab