From d8845735360b39496d30ccf0dee9b1bf8cd5c1bf Mon Sep 17 00:00:00 2001 From: lzydev <1528794076@qq.com> Date: Tue, 21 Feb 2023 21:32:52 +0800 Subject: [PATCH] Support bw invoke fw (#50260) * support bw invoke fw * fix scale in static_backward.yaml * fix the bug in tensorrt/convert * move 'scale','sign' into ops.yaml * add scale_grad of scale in op_compat.yaml * change generated_static_op in CMakeLists.txt --- paddle/fluid/eager/CMakeLists.txt | 2 +- .../tests/performance_tests/CMakeLists.txt | 2 +- paddle/fluid/framework/CMakeLists.txt | 4 +- .../ir/memory_optimize_pass/CMakeLists.txt | 2 +- paddle/fluid/jit/CMakeLists.txt | 2 +- .../fluid/operators/generator/CMakeLists.txt | 24 +++- .../fluid/operators/generator/generate_op.py | 25 ++-- .../operators/generator/generate_static_op.py | 12 ++ .../generator/templates/operator_utils.c.j2 | 25 ++++ paddle/fluid/operators/pscore/CMakeLists.txt | 12 +- paddle/fluid/operators/scale_op.cc | 118 ------------------ paddle/fluid/operators/sign_op.cc | 71 ----------- paddle/fluid/prim/tests/CMakeLists.txt | 1 - paddle/phi/api/yaml/backward.yaml | 12 ++ paddle/phi/api/yaml/legacy_backward.yaml | 12 -- paddle/phi/api/yaml/legacy_ops.yaml | 21 ---- paddle/phi/api/yaml/op_compat.yaml | 14 ++- paddle/phi/api/yaml/ops.yaml | 22 ++++ paddle/phi/api/yaml/static_backward.yaml | 1 + paddle/phi/ops/compat/scale_sig.cc | 75 ----------- 20 files changed, 130 insertions(+), 327 deletions(-) delete mode 100644 paddle/fluid/operators/scale_op.cc delete mode 100644 paddle/fluid/operators/sign_op.cc create mode 100644 paddle/phi/api/yaml/static_backward.yaml delete mode 100644 paddle/phi/ops/compat/scale_sig.cc diff --git a/paddle/fluid/eager/CMakeLists.txt b/paddle/fluid/eager/CMakeLists.txt index beaa82083b6..46d1ff43e0d 100755 --- a/paddle/fluid/eager/CMakeLists.txt +++ b/paddle/fluid/eager/CMakeLists.txt @@ -77,6 +77,6 @@ cc_library( op_registry variable_helper memcpy - scale_op + generated_op autograd_meta hook_utils) diff --git a/paddle/fluid/eager/tests/performance_tests/CMakeLists.txt b/paddle/fluid/eager/tests/performance_tests/CMakeLists.txt index 40f6ac4d94d..dac2b6f2445 100644 --- a/paddle/fluid/eager/tests/performance_tests/CMakeLists.txt +++ b/paddle/fluid/eager/tests/performance_tests/CMakeLists.txt @@ -7,7 +7,7 @@ if(NOT (NOT WITH_PYTHON AND ON_INFER)) ${generated_deps} eager_scale scale_node - scale_op + generated_op matmul_v2_op dygraph_function eager_prim_api) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 31c71928bb0..bced7600dc4 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -1051,7 +1051,7 @@ if(WITH_PSCORE) heter_pipeline_trainer_test SRCS heter_pipeline_trainer_test.cc DEPS conditional_block_op - scale_op + generated_op heter_listen_and_serv_op executor heter_server @@ -1068,7 +1068,7 @@ if(WITH_PSCORE) heter_pipeline_trainer_test SRCS heter_pipeline_trainer_test.cc DEPS conditional_block_op - scale_op + generated_op heter_listen_and_serv_op executor heter_server diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt b/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt index 32d02902e86..1723e881cd5 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt +++ b/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt @@ -76,5 +76,5 @@ cc_library( cc_test( test_reference_count_pass_last_lived_ops SRCS test_reference_count_pass_last_lived_ops.cc - DEPS parallel_executor elementwise_mul_op elementwise_add_op scale_op + DEPS parallel_executor elementwise_mul_op elementwise_add_op generated_op eigen_function) diff --git a/paddle/fluid/jit/CMakeLists.txt b/paddle/fluid/jit/CMakeLists.txt index 150af80d5a8..c42aad1b912 100644 --- a/paddle/fluid/jit/CMakeLists.txt +++ b/paddle/fluid/jit/CMakeLists.txt @@ -65,7 +65,7 @@ if(WITH_TESTING AND NOT WIN32) reduce_mean_op feed_op fetch_op - scale_op + generated_op transfer_layout_op jit_layer) cc_test( diff --git a/paddle/fluid/operators/generator/CMakeLists.txt b/paddle/fluid/operators/generator/CMakeLists.txt index 62c11faadaf..305976c7880 100644 --- a/paddle/fluid/operators/generator/CMakeLists.txt +++ b/paddle/fluid/operators/generator/CMakeLists.txt @@ -11,6 +11,8 @@ set(legacy_bw_op_yaml_file set(sparse_op_yaml_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/sparse_ops.yaml) set(sparse_bw_op_yaml_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/sparse_backward.yaml) +set(static_bw_op_yaml_file + ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/static_backward.yaml) if(NOT PYTHONINTERP_FOUND) find_package(PythonInterp REQUIRED) @@ -66,6 +68,9 @@ execute_process( COMMAND ${PYTHON_EXECUTABLE} parse_op.py --op_yaml_path ${sparse_bw_op_yaml_file} --output_path ./parsed_ops/sparse_backward.parsed.yaml --backward + COMMAND + ${PYTHON_EXECUTABLE} parse_op.py --op_yaml_path ${static_bw_op_yaml_file} + --output_path ./parsed_ops/static_backward.parsed.yaml --backward RESULTS_VARIABLE _results) foreach(_result in ${_results}) if(${_result}) @@ -82,14 +87,24 @@ execute_process( COMMAND ${PYTHON_EXECUTABLE} cross_validate.py --forward_yaml_paths ./parsed_ops/ops.parsed.yaml ./parsed_ops/legacy_ops.parsed.yaml - ./parsed_ops/static_ops.parsed.yaml --backward_yaml_paths - ./parsed_ops/backward_ops.parsed.yaml + --backward_yaml_paths ./parsed_ops/backward_ops.parsed.yaml ./parsed_ops/legacy_backward_ops.parsed.yaml RESULT_VARIABLE _result) if(${_result}) message(FATAL_ERROR "ops validation failed, exiting.") endif() +execute_process( + WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator + COMMAND + ${PYTHON_EXECUTABLE} cross_validate.py --forward_yaml_paths + ./parsed_ops/static_ops.parsed.yaml --backward_yaml_paths + ./parsed_ops/static_backward.parsed.yaml + RESULT_VARIABLE _result) +if(${_result}) + message(FATAL_ERROR "static ops validation failed, exiting.") +endif() + execute_process( WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator COMMAND @@ -124,8 +139,9 @@ endif() execute_process( WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator COMMAND - ${PYTHON_EXECUTABLE} generate_static_op.py --ops_yaml_path - ./parsed_ops/static_ops.parsed.yaml --op_version_yaml_path + ${PYTHON_EXECUTABLE} generate_op.py --ops_yaml_path + ./parsed_ops/static_ops.parsed.yaml --backward_yaml_path + ./parsed_ops/static_backward.parsed.yaml --op_version_yaml_path ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_version.yaml --op_compat_yaml_path ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml --output_op_path "${generated_static_op_path}.tmp" --output_arg_map_path diff --git a/paddle/fluid/operators/generator/generate_op.py b/paddle/fluid/operators/generator/generate_op.py index 2da40b1edd1..a3a85946022 100644 --- a/paddle/fluid/operators/generator/generate_op.py +++ b/paddle/fluid/operators/generator/generate_op.py @@ -407,6 +407,7 @@ def process_invoke_op(forward_op_dict, backward_op_dict): invoke_op = bw_op['invoke']['func'] args_list = bw_op['invoke']['args'] args_index = 0 + # backward invoke forward if invoke_op in forward_op_dict: reuse_op = forward_op_dict[invoke_op] bw_op['invoke']['func'] = reuse_op['op_name'] @@ -460,17 +461,16 @@ def parse_drop_empty_grad(op_fluid_list: list, bw_op_dict: dict): for bw_name in op_op['backward'].split(',') ] for bw_name in bw_names: - assert ( - bw_name in bw_op_dict - ), f"backward {bw_name} is not existed" - for out_grad in op_op['drop_empty_grad']: - assert ( - out_grad in bw_op_dict[bw_name]['output_dict'] - ), f''' - {bw_name} with {out_grad} is not existed in output_dict ''' - bw_op_dict[bw_name]['output_dict'][out_grad][ - 'drop_empty_grad' - ] = False + # static_ops.yaml and ops.yaml use the common op_compat.yaml + if bw_name in bw_op_dict: + for out_grad in op_op['drop_empty_grad']: + assert ( + out_grad in bw_op_dict[bw_name]['output_dict'] + ), f''' + {bw_name} with {out_grad} is not existed in output_dict ''' + bw_op_dict[bw_name]['output_dict'][out_grad][ + 'drop_empty_grad' + ] = False def main( @@ -493,7 +493,8 @@ def main( op_versions = yaml.safe_load(f) # add op version info into op for op_version in op_versions: - forward_op_dict[op_version['op']]['version'] = op_version['version'] + if op_version['op'] in forward_op_dict: + forward_op_dict[op_version['op']]['version'] = op_version['version'] with open(op_compat_yaml_path, "rt") as f: op_fluid_map_list = yaml.safe_load(f) diff --git a/paddle/fluid/operators/generator/generate_static_op.py b/paddle/fluid/operators/generator/generate_static_op.py index 3a825bafb12..0b79b6ab9f7 100644 --- a/paddle/fluid/operators/generator/generate_static_op.py +++ b/paddle/fluid/operators/generator/generate_static_op.py @@ -81,6 +81,7 @@ def restruct_io(op): def main( ops_yaml_path, + backward_yaml_path, op_compat_yaml_path, op_version_yaml_path, output_op_path, @@ -91,6 +92,11 @@ def main( ops = [restruct_io(op) for op in ops] forward_op_dict = to_named_dict(ops) + with open(backward_yaml_path, "rt") as f: + backward_ops = yaml.safe_load(f) + backward_ops = [restruct_io(op) for op in backward_ops] + backward_op_dict = to_named_dict(backward_ops) + with open(op_version_yaml_path, "rt") as f: op_versions = yaml.safe_load(f) @@ -139,6 +145,11 @@ if __name__ == "__main__": parser.add_argument( '--ops_yaml_path', type=str, help="parsed static ops yaml file." ) + parser.add_argument( + '--backward_yaml_path', + type=str, + help="parsed static backward ops yaml file.", + ) parser.add_argument( '--op_compat_yaml_path', type=str, help="ops args compat yaml file." ) @@ -157,6 +168,7 @@ if __name__ == "__main__": args = parser.parse_args() main( args.ops_yaml_path, + args.backward_yaml_path, args.op_compat_yaml_path, args.op_version_yaml_path, args.output_op_path, diff --git a/paddle/fluid/operators/generator/templates/operator_utils.c.j2 b/paddle/fluid/operators/generator/templates/operator_utils.c.j2 index 63392bb786f..ac260cd4f64 100644 --- a/paddle/fluid/operators/generator/templates/operator_utils.c.j2 +++ b/paddle/fluid/operators/generator/templates/operator_utils.c.j2 @@ -522,6 +522,31 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker true)}}); {% endfor %} + {% for attr in invoke_op["attrs"] %} + {% set attr_name = attr["fluid_name"] %} + {% set fw_attrs = forward_op["attrs"] %} + {% if attr_name in forward_attr_names %} + {# invoke_op's attrs and fw_attr's attrs must be the same#} + {% set fw_attr = fw_attrs[loop.index0] %} + {% if fw_attr["typename"] == "IntArray" %} + {% if 'tensor_name' in attr or 'manual_flag' not in attr %} + if (this->HasInput("{{fw_attr | to_int_array_tensor_name}}")) { + grad_op->SetInput("{{fw_attr | to_int_array_tensor_name}}", this->Input("{{fw_attr | to_int_array_tensor_name}}")); + } + {% endif %} + {% if 'tensors_name' in fw_attr or 'manual_flag' not in fw_attr %} + if (this->HasInput("{{fw_attr | to_int_array_tensors_name}}")) { + grad_op->SetInput("{{fw_attr | to_int_array_tensors_name}}", this->Input("{{fw_attr | to_int_array_tensors_name}}")); + } + {% endif %} + {% elif fw_attr["typename"] == "Scalar" %} + if (this->HasInput("{{fw_attr | to_scalar_tensor_name}}")) { + grad_op->SetInput("{{fw_attr | to_scalar_tensor_name}}", this->Input("{{fw_attr | to_scalar_tensor_name}}")); + } + {% endif %} + {% endif %} + {% endfor %} + {% for attr in invoke_op["attrs"] %} grad_op->SetAttr("{{attr["fluid_name"]}}", {{attr["value"]}}); {% endfor %} diff --git a/paddle/fluid/operators/pscore/CMakeLists.txt b/paddle/fluid/operators/pscore/CMakeLists.txt index 1c4771c7b4d..89b33d9a144 100755 --- a/paddle/fluid/operators/pscore/CMakeLists.txt +++ b/paddle/fluid/operators/pscore/CMakeLists.txt @@ -86,7 +86,7 @@ cc_test_old( executor scope proto_desc - scale_op + generated_op eigen_function) set_source_files_properties( @@ -100,7 +100,7 @@ cc_test_old( executor scope proto_desc - scale_op + generated_op send_and_recv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} @@ -117,7 +117,7 @@ cc_test_old( executor scope proto_desc - scale_op + generated_op send_and_recv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} @@ -134,14 +134,14 @@ cc_test_old( executor scope proto_desc - scale_op + generated_op heter_listen_and_serv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} eigen_function) #set_source_files_properties(heter_cloud_comm_cpu_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -#cc_test(heter_cloud_comm_cpu_test SRCS heter_cloud_comm_cpu_test.cc DEPS executor scope proto_desc scale_op heter_listen_and_serv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} eigen_function) +#cc_test(heter_cloud_comm_cpu_test SRCS heter_cloud_comm_cpu_test.cc DEPS executor scope proto_desc generated_static_op heter_listen_and_serv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} eigen_function) set_source_files_properties( switch_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) @@ -153,7 +153,7 @@ cc_binary( executor scope proto_desc - scale_op + generated_op heter_listen_and_serv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} diff --git a/paddle/fluid/operators/scale_op.cc b/paddle/fluid/operators/scale_op.cc deleted file mode 100644 index 2cfd0969865..00000000000 --- a/paddle/fluid/operators/scale_op.cc +++ /dev/null @@ -1,118 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/float16.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/unary.h" - -namespace paddle { -namespace operators { - -class ScaleOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - auto input_data_type = - framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return phi::KernelKey(input_data_type, ctx.GetPlace()); - } -}; - -class ScaleOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor) Input tensor of scale operator."); - AddInput("ScaleTensor", - "(Tensor) If provided, use this as " - "scale factor, this has a higher priority than " - "attr(scale), the shape of this tensor MUST BE 1.") - .AsDispensable(); - AddOutput("Out", "(Tensor) Output tensor of scale operator."); - AddComment(R"DOC( -**Scale operator** - -Apply scaling and bias addition to the input tensor. - -if bias_after_scale=True: - -$$Out = scale*X + bias$$ - -else: - -$$Out = scale*(X + bias)$$ -)DOC"); - AddAttr("scale", "The scaling factor of the scale operator.") - .SetDefault(1.0); - AddAttr("bias", "The bias of the scale operator.").SetDefault(0.0); - AddAttr( - "bias_after_scale", - "Apply bias addition after or before scaling. It is useful for " - "numeric stability in some circumstances.") - .SetDefault(true); - } -}; - -class ScaleOpVarTypeInference : public framework::VarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - ctx->SyncTypeAndDataType("X", "Out"); - } -}; - -template -class ScaleGradMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - void Apply(GradOpPtr grad_op) const override { - grad_op->SetType("scale"); - grad_op->SetInput("X", this->OutputGrad("Out")); - if (this->HasInput("ScaleTensor") > 0) { - grad_op->SetInput("ScaleTensor", this->Input("ScaleTensor")); - } - grad_op->SetOutput("Out", this->InputGrad("X")); - VLOG(6) << "Finish SetOutput"; - grad_op->SetAttr("scale", this->GetAttr("scale")); - VLOG(6) << "Finish Set Attr scale"; - grad_op->SetAttr("bias", 0.0f); - VLOG(6) << "Finish Set Attr bias"; - grad_op->SetAttr("bias_after_scale", true); - VLOG(6) << "Finish Set Attr bias_after_scale"; - VLOG(6) << "Finish Apply"; - } -}; - -DECLARE_INPLACE_OP_INFERER(ScaleOpInplaceInferer, {"X", "Out"}); -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -DECLARE_INFER_SHAPE_FUNCTOR(scale, - ScaleInferShapeFunctor, - PD_INFER_META(phi::UnchangedInferMeta)); -REGISTER_OPERATOR(scale, - ops::ScaleOp, - ops::ScaleOpMaker, - ops::ScaleGradMaker, - ops::ScaleGradMaker, - ScaleInferShapeFunctor, - ops::ScaleOpVarTypeInference, - ops::ScaleOpInplaceInferer); diff --git a/paddle/fluid/operators/sign_op.cc b/paddle/fluid/operators/sign_op.cc deleted file mode 100644 index f6a7074df10..00000000000 --- a/paddle/fluid/operators/sign_op.cc +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/float16.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/unary.h" - -namespace paddle { -namespace operators { - -class SignOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; - -template -class SignOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor) Input tensor of sign operator."); - AddOutput("Out", "(Tensor) Output tensor of sign operator."); - AddComment(R"DOC( -Sign operator - -$$Out = X.sign()$$ -)DOC"); - } -}; - -template -class SignGradMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - void Apply(GradOpPtr grad_op) const override { - grad_op->SetType("scale"); - grad_op->SetInput("X", this->OutputGrad("Out")); - grad_op->SetOutput("Out", this->InputGrad("X")); - grad_op->SetAttr("scale", 0.0f); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -DECLARE_INFER_SHAPE_FUNCTOR(sign, - SignInferShapeFunctor, - PD_INFER_META(phi::UnchangedInferMeta)); -REGISTER_OPERATOR(sign, - ops::SignOp, - ops::SignOpMaker, - ops::SignGradMaker, - ops::SignGradMaker, - SignInferShapeFunctor); diff --git a/paddle/fluid/prim/tests/CMakeLists.txt b/paddle/fluid/prim/tests/CMakeLists.txt index cc082a84299..f85108a7c8e 100644 --- a/paddle/fluid/prim/tests/CMakeLists.txt +++ b/paddle/fluid/prim/tests/CMakeLists.txt @@ -29,7 +29,6 @@ cc_test_old( prim_utils operator elementwise_mul_op - scale_op activation_op phi_api phi_dygraph_api diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 5b900da998c..9cfcfdea899 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -1103,6 +1103,12 @@ backward : rsqrt_double_grad inplace : (out_grad -> x_grad) +- backward_op : scale_grad + forward : scale (Tensor x, Scalar scale, float bias, bool bias_after_scale) -> Tensor(out) + args : (Tensor out_grad, Scalar scale=1.0) + output : Tensor(x_grad) + invoke : scale(out_grad, scale, 0.0f, true) + - backward_op : scatter_grad forward : scatter (Tensor x, Tensor index, Tensor updates, bool overwrite=true) -> Tensor(out) args : (Tensor index, Tensor updates, Tensor out_grad, bool overwrite) @@ -1207,6 +1213,12 @@ optional : grad_grad_out_grad inplace : (grad_grad_x -> fwd_grad_out_grad) +- backward_op : sign_grad + forward : sign (Tensor x) -> Tensor(out) + args : (Tensor out_grad) + output : Tensor(x_grad) + invoke : scale(out_grad, 0.0f, 0.0f, true) + - backward_op : silu_grad forward : silu (Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index c6cbaaefdc5..b84f011296d 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -1144,12 +1144,6 @@ func : rrelu_grad data_type : x -- backward_op : scale_grad - forward : scale (Tensor x, Scalar scale, float bias, bool bias_after_scale) -> Tensor(out) - args : (Tensor out_grad, Scalar scale=1.0, bool bias_after_scale=true) - output : Tensor(x_grad) - invoke : scale(out_grad, scale, 0.0, bias_after_scale) - - backward_op : segment_pool_grad forward : segment_pool (Tensor x, Tensor segment_ids, str pooltype) -> Tensor(out), Tensor(summed_ids) args : (Tensor x, Tensor segment_ids, Tensor out, Tensor summed_ids, Tensor out_grad, str pooltype) @@ -1173,12 +1167,6 @@ func : sigmoid_cross_entropy_with_logits_grad inplace : (out_grad -> x_grad) -- backward_op : sign_grad - forward : sign (Tensor x) -> Tensor(out) - args : (Tensor out_grad) - output : Tensor(x_grad) - invoke : scale(out_grad, 0.0, 0.0, true) - - backward_op : slice_double_grad forward : slice_grad (Tensor input, Tensor grad_out, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis) -> Tensor(grad_input) args : (Tensor grad_input_grad, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index e6b7124f79c..2f59c858a03 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -1559,18 +1559,6 @@ intermediate : noise backward : rrelu_grad -- op : scale - args : (Tensor x, Scalar scale, float bias, bool bias_after_scale) - output : Tensor(out) - infer_meta : - func : UnchangedInferMeta - param : [x] - kernel : - func : scale {dense -> dense}, - scale_sr {selected_rows -> selected_rows} - inplace : (x -> out) - backward : scale_grad - - op : segment_pool args : (Tensor x, Tensor segment_ids, str pooltype) output : Tensor(out), Tensor(summed_ids) @@ -1616,15 +1604,6 @@ func : sigmoid_cross_entropy_with_logits backward : sigmoid_cross_entropy_with_logits_grad -- op : sign - args : (Tensor x) - output : Tensor(out) - infer_meta : - func : UnchangedInferMeta - kernel : - func : sign - backward : sign_grad - - op : slice args : (Tensor input, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis) output : Tensor diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 3f4cfbdc25b..a34ee5471a5 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1326,10 +1326,15 @@ attrs : [bool use_mkldnn = false, bool use_cudnn = false] - op : scale + backward : scale_grad inputs : x : X outputs : - out: Out + out : Out + scalar : + scale : + data_type : float + tensor_name : ScaleTensor extra : attrs : [bool use_mkldnn = false] @@ -1425,6 +1430,13 @@ extra : attrs : [bool use_mkldnn = false, bool use_cudnn = false] +- op : sign + backward : sign_grad + inputs : + x : X + outputs : + out : Out + - op : silu backward : silu_grad inputs : diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index d203bda3d74..b783601e6d0 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1022,6 +1022,19 @@ inplace : (x -> out) backward : rsqrt_grad +- op : scale + args : (Tensor x, Scalar scale=1.0, float bias=0.0, bool bias_after_scale=true) + output : Tensor(out) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : scale {dense -> dense}, + scale_sr {selected_rows -> selected_rows} + data_type : x + inplace : (x -> out) + backward : scale_grad + - op : scatter args : (Tensor x, Tensor index, Tensor updates, bool overwrite=true) output : Tensor(out) @@ -1111,6 +1124,15 @@ func : sigmoid backward : sigmoid_grad +- op : sign + args : (Tensor x) + output : Tensor(out) + infer_meta : + func : UnchangedInferMeta + kernel : + func : sign + backward : sign_grad + - op : silu args : (Tensor x) output : Tensor diff --git a/paddle/phi/api/yaml/static_backward.yaml b/paddle/phi/api/yaml/static_backward.yaml new file mode 100644 index 00000000000..8c9cbb6cfb3 --- /dev/null +++ b/paddle/phi/api/yaml/static_backward.yaml @@ -0,0 +1 @@ +# This file is to support those static ops different the dynamic. diff --git a/paddle/phi/ops/compat/scale_sig.cc b/paddle/phi/ops/compat/scale_sig.cc deleted file mode 100644 index 8061a1fbd61..00000000000 --- a/paddle/phi/ops/compat/scale_sig.cc +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -/** - * Note [ Why does the ArgumentMapping function need to be so complicated? ] - * - * In order to meet the requirements of infrt, the function used to match Op - * and Kernel parameters, need to be placed in phi as a compatible component, - * and does not depend on fluid. - * - * Because infrt not only needs to dynamically call this argument mapping - * function at runtime, but also needs to statically declare all possible - * results of the function before running without any information. - * - * The infrt declare like: - * - * def PDKEL_Reshape_to_CPU : Pat< - * (PD_ReshapeOp $x, $shape_tensor, $shape_attr), // OpMaker arguments - * (PDKEL_ReshapeKernelAttr $x, fn($shape_attr)>; // Kernel arguments - * def PDKEL_Reshape_to_CPU : Pat< - * (PD_ReshapeOp $x, $shape_tensor, $shape_attr), - * (PDKEL_ReshapeKernelAttr $x, fn($shape_tensor)>; - * - * Therefore, we need to write out each result of the argument mapping function, - * like `KernelSignature("full", {}, {"ShapeTensor", "value"}, {"Out"})`, it - * cannot contains variable, only can contains const char* string. - * - * Infrt will parse all results before running for the generation of the above - * static declare, which leads to some functions being written in a long way, - * and the complicated ones may have hundreds of lines, which has certain side - * effects on the programming experience. - */ -KernelSignature ScaleOpArgumentMapping(const ArgumentMappingContext& ctx) { - if (ctx.IsDenseTensorInput("X")) { - if (ctx.HasInput("ScaleTensor")) { - return KernelSignature( - "scale", {"X"}, {"ScaleTensor", "bias", "bias_after_scale"}, {"Out"}); - } else { - return KernelSignature( - "scale", {"X"}, {"scale", "bias", "bias_after_scale"}, {"Out"}); - } - } else if (ctx.IsSelectedRowsInput("X")) { - if (ctx.HasInput("ScaleTensor")) { - return KernelSignature("scale_sr", - {"X"}, - {"ScaleTensor", "bias", "bias_after_scale"}, - {"Out"}); - } else { - return KernelSignature( - "scale_sr", {"X"}, {"scale", "bias", "bias_after_scale"}, {"Out"}); - } - } else { - return KernelSignature("unregistered", {}, {}, {}); - } -} - -} // namespace phi - -// op_type, api_name, arg_mapping_fn -PD_REGISTER_ARG_MAPPING_FN(scale, phi::ScaleOpArgumentMapping); -- GitLab