diff --git a/paddle/fluid/operators/generator/filters.py b/paddle/fluid/operators/generator/filters.py index da33c03e0c76e83b8af2d1cacae722eadca4408c..6520a2e07511f2d8a9cb24e42ff4ae5c28fcd040 100644 --- a/paddle/fluid/operators/generator/filters.py +++ b/paddle/fluid/operators/generator/filters.py @@ -65,6 +65,25 @@ def to_dense_input_type(s, optional=False): return dense_optional_input_types_map[s] +def assert_dense_or_sr(input_type): + return ( + "ctx.IsSelectedRowsInput" + if input_type == "selected_rows" + else "ctx.IsDenseTensorInput" + ) + + +def find_optinal_inputs_name(inputs): + optional_inputs_name = [ + input["fluid_name"] for input in inputs if input["optional"] is True + ] + return optional_inputs_name + + +def delete_last_underline(op_name): + return op_name if op_name[-1] != '_' else op_name[:-1] + + # ------------------------------ output ---------------------------------- def to_paddle_output_type(s): return output_type_map[s] diff --git a/paddle/fluid/operators/generator/generate_op.py b/paddle/fluid/operators/generator/generate_op.py index 96df7e71a5ea3c40d2662c5d005d3caad3ecac54..223e6b714ae4cdb18130cb3f47e71d091a49e731 100644 --- a/paddle/fluid/operators/generator/generate_op.py +++ b/paddle/fluid/operators/generator/generate_op.py @@ -19,7 +19,10 @@ from pathlib import Path import yaml from filters import ( + assert_dense_or_sr, cartesian_prod_mapping, + delete_last_underline, + find_optinal_inputs_name, to_composite_grad_opmaker_name, to_input_name, to_int_array_tensor_name, @@ -63,6 +66,8 @@ env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping env.filters["to_composite_grad_opmaker_name"] = to_composite_grad_opmaker_name env.filters["to_variable_names"] = to_variable_names +env.filters["assert_dense_or_sr"] = assert_dense_or_sr +env.filters["find_optinal_inputs_name"] = find_optinal_inputs_name env.tests["base_op"] = is_base_op env.tests["composite_op"] = is_composite_op env.tests["vec"] = is_vec @@ -478,19 +483,20 @@ def parse_get_expected_kerneltype( fw_op_dict[fw_name][ "get_expected_kernel_type" ] = op_comp_map['get_expected_kernel_type'][fw_name] - bw_names = [ - bw_name.split('(')[0].strip() - for bw_name in op_comp_map['backward'].split(',') - ] - for bw_name in bw_names: - # static_ops.yaml and ops.yaml use the common op_compat.yaml - if ( - bw_name in bw_op_dict - and bw_name in op_comp_map['get_expected_kernel_type'] - ): - bw_op_dict[bw_name][ - "get_expected_kernel_type" - ] = op_comp_map['get_expected_kernel_type'][bw_name] + if "backward" in op_comp_map: + bw_names = [ + bw_name.split('(')[0].strip() + for bw_name in op_comp_map['backward'].split(',') + ] + for bw_name in bw_names: + # static_ops.yaml and ops.yaml use the common op_compat.yaml + if ( + bw_name in bw_op_dict + and bw_name in op_comp_map['get_expected_kernel_type'] + ): + bw_op_dict[bw_name][ + "get_expected_kernel_type" + ] = op_comp_map['get_expected_kernel_type'][bw_name] def parse_keep_signature( @@ -528,6 +534,20 @@ def split_ops_list(ops, backward_op_dict, split_num): return new_ops_list, new_bw_ops_list +def to_phi_and_fluid_op_name_without_underline(op_item): + ''' + If the op_name ends with '_', delete the last '_'. 
For an example, 'sgd_' becomes 'sgd + ''' + names = op_item.split('(') + if len(names) == 1: + op_kernel_name = delete_last_underline(names[0].strip()) + return op_kernel_name + else: + op_name = delete_last_underline(names[0].strip()) + kernel_name = delete_last_underline(names[1].split(')')[0].strip()) + return op_name + '(' + kernel_name + ')' + + def main( ops_yaml_path, backward_yaml_path, @@ -539,11 +559,11 @@ def main( with open(ops_yaml_path, "rt") as f: ops = yaml.safe_load(f) ops = [restruct_io(op) for op in ops] - forward_op_dict = to_named_dict(ops) + forward_op_dict = to_named_dict(ops, True) with open(backward_yaml_path, "rt") as f: backward_ops = yaml.safe_load(f) backward_ops = [restruct_io(op) for op in backward_ops] - backward_op_dict = to_named_dict(backward_ops) + backward_op_dict = to_named_dict(backward_ops, True) with open(op_version_yaml_path, "rt") as f: op_versions = yaml.safe_load(f) # add op version info into op @@ -553,6 +573,10 @@ def main( with open(op_compat_yaml_path, "rt") as f: op_fluid_map_list = yaml.safe_load(f) + for op_args in op_fluid_map_list: + op_args["op"] = to_phi_and_fluid_op_name_without_underline( + op_args["op"] + ) for op in ops: op['op_name'] = op['name'] diff --git a/paddle/fluid/operators/generator/generate_sparse_op.py b/paddle/fluid/operators/generator/generate_sparse_op.py index 1991b1fd2227df93ab296188f7407081474f58da..cee478e0df6277c6e36e0af9490e176f64aec9e0 100644 --- a/paddle/fluid/operators/generator/generate_sparse_op.py +++ b/paddle/fluid/operators/generator/generate_sparse_op.py @@ -18,7 +18,9 @@ from pathlib import Path import yaml from filters import ( + assert_dense_or_sr, cartesian_prod_mapping, + find_optinal_inputs_name, to_composite_grad_opmaker_name, to_input_name, to_int_array_tensor_name, @@ -59,6 +61,8 @@ env.filters["to_scalar_tensor_name"] = to_scalar_tensor_name env.filters["to_int_array_tensor_name"] = to_int_array_tensor_name env.filters["to_int_array_tensors_name"] = to_int_array_tensors_name env.filters["to_input_name"] = to_input_name +env.filters["assert_dense_or_sr"] = assert_dense_or_sr +env.filters["find_optinal_inputs_name"] = find_optinal_inputs_name env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping env.filters["to_composite_grad_opmaker_name"] = to_composite_grad_opmaker_name diff --git a/paddle/fluid/operators/generator/get_expected_kernel_func.cc b/paddle/fluid/operators/generator/get_expected_kernel_func.cc index dc08e0b5cec2a910d3fdf081928938abfcdad369..db963aa27ae53db7f9293c188d1cd525eb7de419 100644 --- a/paddle/fluid/operators/generator/get_expected_kernel_func.cc +++ b/paddle/fluid/operators/generator/get_expected_kernel_func.cc @@ -101,5 +101,27 @@ phi::KernelKey GetReduceGradExpectedKernelType( return phi::KernelKey(input_data_type, ctx.GetPlace()); } +phi::KernelKey GetSgdExpectedKernelType( + const framework::ExecutionContext& ctx, + const framework::OperatorWithKernel* op_ptr) { + auto data_type = op_ptr->IndicateVarDataType(ctx, "Param"); + + // NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN + const auto* param_var = ctx.InputVar("Param"); + const auto* grad_var = ctx.InputVar("Grad"); + + // supported cases + bool dense_param_sparse_grad = param_var->IsType() && + grad_var->IsType(); + bool dense_param_and_grad = param_var->IsType() && + grad_var->IsType(); + if (!(dense_param_sparse_grad || dense_param_and_grad)) { + op_ptr->SetDnnFallback(true); + } + // NOTE(jiahongyu): Above codes 
originally enclosed by PADDLE_WITH_MKLDNN + + return phi::KernelKey(data_type, ctx.GetPlace()); +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/generator/get_expected_kernel_func.h b/paddle/fluid/operators/generator/get_expected_kernel_func.h index 2054d593fb33608a07602464206fb62321bbc75e..a5883f44d7c9f905a2e23f2866af3642318177c2 100644 --- a/paddle/fluid/operators/generator/get_expected_kernel_func.h +++ b/paddle/fluid/operators/generator/get_expected_kernel_func.h @@ -28,5 +28,9 @@ phi::KernelKey GetReduceGradExpectedKernelType( const framework::ExecutionContext& ctx, const framework::OperatorWithKernel* op_ptr); +phi::KernelKey GetSgdExpectedKernelType( + const framework::ExecutionContext& ctx, + const framework::OperatorWithKernel* op_ptr); + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/generator/ops_extra_info_gen.py b/paddle/fluid/operators/generator/ops_extra_info_gen.py index 11856daeba2fa4c19707d3b1b576ce131a392ea4..a6482908ba631e1eb822160f22632d2859a74932 100644 --- a/paddle/fluid/operators/generator/ops_extra_info_gen.py +++ b/paddle/fluid/operators/generator/ops_extra_info_gen.py @@ -16,6 +16,7 @@ import argparse import re import yaml +from filters import delete_last_underline def map_code_template(attrs_str, attrs_checker_str): @@ -77,9 +78,9 @@ def generate_extra_info(op_compat_yaml_path, ops_extra_info_path): def get_op_name(api_item): names = api_item.split('(') if len(names) == 1: - return names[0].strip() + return delete_last_underline(names[0].strip()) else: - return names[1].split(')')[0].strip() + return delete_last_underline(names[1].split(')')[0].strip()) extra_map_str_list = [] extra_checker_str_list = [] diff --git a/paddle/fluid/operators/generator/parse_utils.py b/paddle/fluid/operators/generator/parse_utils.py index 1329e940446ade392d68f759e62a18a3c5f9917d..aef27367a904174e70f9159d9d88950fbd847c40 100644 --- a/paddle/fluid/operators/generator/parse_utils.py +++ b/paddle/fluid/operators/generator/parse_utils.py @@ -20,13 +20,23 @@ from tests_utils import is_attr, is_input, is_output, is_vec from type_mapping import opmaker_attr_types_map -def to_named_dict(items: List[Dict]) -> Dict[str, Dict]: +def to_named_dict(items: List[Dict], is_op=False) -> Dict[str, Dict]: named_dict = {} - for item in items: - if "name" not in item: - raise KeyError(f"name not in {item}") - name = item["name"] - named_dict[name] = item + if is_op: + for item in items: + if "name" not in item: + raise KeyError(f"name not in {item}") + item["name"] = ( + item["name"] if item["name"][-1] != '_' else item["name"][:-1] + ) + name = item["name"] + named_dict[name] = item + else: + for item in items: + if "name" not in item: + raise KeyError(f"name not in {item}") + name = item["name"] + named_dict[name] = item return named_dict diff --git a/paddle/fluid/operators/generator/templates/operator_utils.c.j2 b/paddle/fluid/operators/generator/templates/operator_utils.c.j2 index 33233bd77a141d15a75c2377baada2ca18d00d5c..e252e42ca7bac3189d70740539341daa290404a8 100644 --- a/paddle/fluid/operators/generator/templates/operator_utils.c.j2 +++ b/paddle/fluid/operators/generator/templates/operator_utils.c.j2 @@ -119,6 +119,34 @@ static_cast(phi::Place({{"phi::" if not default_value is initializer_list}} {%- endmacro %} +{% macro choose_kernel_signature(inputs_type, optional_inputs_name, kernel_args, kernel_name, is_first) %} {#inline#} + {%- set inputs_len = inputs_type | length -%} + {%- if is_first -%} +if ( + {%- else -%} 
+else if ( + {%- endif -%} + {%- for input_type in inputs_type -%} + {%- set kernel_arg_name = kernel_args[loop.index0] | to_opmaker_name_cstr -%} + {%- if loop.index0 != inputs_len - 1%} + {%- if kernel_args[loop.index0] in optional_inputs_name %} + ((ctx.HasInput({{kernel_arg_name}}) && {{input_type | assert_dense_or_sr}}({{kernel_arg_name}})) || (!ctx.HasInput({{kernel_arg_name}}))) && + {% else %} + {{input_type | assert_dense_or_sr}}({{kernel_arg_name}}) && + {% endif %} + {% else %} {# the last param #} + {% if kernel_args[loop.index0] in optional_inputs_name -%} + ((ctx.HasInput({{kernel_arg_name}}) && {{input_type | assert_dense_or_sr}}({{kernel_arg_name}})) || (!ctx.HasInput({{kernel_arg_name}})))) + {%- else -%} + {{input_type | assert_dense_or_sr}}({{kernel_arg_name}})) + {%- endif %} { + return KernelSignature("{{kernel_name}}", std::move(inputs), std::move(attrs), std::move(outputs)); + } + {% endif %} + {% endfor %} +{%- endmacro -%} + + {# --------------------------------------- name mapping ---------------------------------------------- #} {% macro name_map(op) %} /* @@ -132,21 +160,39 @@ All possible KernelSignatures returned by {{op["name"] | to_pascal_case }}OpArgu KernelSignature {{op["op_name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) { {% set kernel_args = op["kernel"]["param"] %} + {% set optional_inputs_name = op["inputs"]| find_optinal_inputs_name %} {{get_input_list(op["inputs"], kernel_args)}}; paddle::small_vector attrs; {% for attr in op["attrs"]%} - {% filter indent(2)%} + {% filter indent(2)%} {{get_an_attr(attr, kernel_args)}} - {% endfilter %} + {% endfilter %} {% endfor %} {{get_output_list(op["outputs"], kernel_args)}}; - {% if op["kernel"]["func"] | length == 1 %} - KernelSignature sig("{{op["kernel"]["func"][0]}}", std::move(inputs), std::move(attrs), std::move(outputs)); - return sig; - {% else %}{# it has kernel for selected rows #} - const char* kernel_name = ctx.IsSelectedRowsInput({{kernel_args[0] | to_opmaker_name_cstr}}) ? 
"{{op["kernel"]["func"][1]}}" : "{{op["kernel"]["func"][0]}}"; - KernelSignature sig (kernel_name, std::move(inputs), std::move(attrs), std::move(outputs)); - return sig; + {% set kernel_num = op["kernel"]["func"] | length %} + {% if kernel_num == 1 %} + return KernelSignature("{{op["kernel"]["func"][0]}}", std::move(inputs), std::move(attrs), std::move(outputs)); + {% elif kernel_num == 2 %}{# it has kernel for selected rows #} + {% set fun_name = op["kernel"]["func"][0] %} + {% set inputs_type = op["kernel"]["dispatch"][fun_name][0] %} +{{choose_kernel_signature(inputs_type, optional_inputs_name, kernel_args, fun_name, true)}} + {%- set fun_name = op["kernel"]["func"][1] -%} + {%- set inputs_type = op["kernel"]["dispatch"][fun_name][0] -%} +{{choose_kernel_signature(inputs_type, optional_inputs_name, kernel_args, fun_name, false)-}} + else { return KernelSignature("unregistered", {}, {}, {}); } + {% elif kernel_num == 3 %}{# it has kernel for selected rows #} + {%- set fun_name = op["kernel"]["func"][0] -%} + {%- set inputs_type = op["kernel"]["dispatch"][fun_name][0] -%} +{{choose_kernel_signature(inputs_type, optional_inputs_name, kernel_args, fun_name, true)}} + {%- set fun_name = op["kernel"]["func"][1] -%} + {%- set inputs_type = op["kernel"]["dispatch"][fun_name][0] -%} +{{choose_kernel_signature(inputs_type, optional_inputs_name, kernel_args, fun_name, false)-}} + {%- set fun_name = op["kernel"]["func"][2] -%} + {%- set inputs_type = op["kernel"]["dispatch"][fun_name][0] -%} +{{choose_kernel_signature(inputs_type, optional_inputs_name, kernel_args, fun_name, false)-}} + else { return KernelSignature("unregistered", {}, {}, {}); } + {% else %} {# only support kernel_num <= 3 #} + return KernelSignature("unregistered", {}, {}, {}); {%endif%} } {% endmacro %} @@ -395,7 +441,7 @@ class {{op["op_name"] | to_pascal_case}}Op : public framework::OperatorWithKerne DECLARE_INFER_SHAPE_FUNCTOR({{op["op_name"]}}, {{op["op_name"] | to_pascal_case}}InferShapeFunctor, PD_INFER_META(phi::{{op["infer_meta"]["func"]}})); {# inplace inferer #} -{% if op["inplace"] is not none %} +{% if op["inplace"] is not none and op["inplace"] | length == 1%} {% set inplace_map %} {% for source, target in op["inplace"].items() %} {{"{"}}{{target | to_opmaker_name}}, {{source | to_opmaker_name}}{{"}"}}{{", " if not loop.last}} diff --git a/paddle/fluid/operators/generator/tests_utils.py b/paddle/fluid/operators/generator/tests_utils.py index 424ecb09c2d2c588e975b3a2b039b036c649fb05..574f3663b7d5603f18518d862023b9a296970442 100644 --- a/paddle/fluid/operators/generator/tests_utils.py +++ b/paddle/fluid/operators/generator/tests_utils.py @@ -63,7 +63,7 @@ def supports_selected_rows_kernel(op): def supports_inplace(op): - return op['inplace'] is not None + return op['inplace'] is not None and len(op['inplace']) == 1 def supports_no_need_buffer(op): diff --git a/paddle/fluid/operators/optimizers/lamb_op.cc b/paddle/fluid/operators/optimizers/lamb_op.cc deleted file mode 100644 index c6c4397332280ef96a671de96f14e371bd107695..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/optimizers/lamb_op.cc +++ /dev/null @@ -1,156 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/infermeta/multiary.h" -#include "paddle/phi/kernels/lamb_kernel.h" - -namespace paddle { -namespace operators { - -class LambOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext &ctx) const { - auto input_data_type = - OperatorWithKernel::IndicateVarDataType(ctx, "Param"); - return phi::KernelKey(input_data_type, ctx.GetPlace()); - } - phi::KernelKey GetKernelTypeForVar( - const std::string &var_name, - const phi::DenseTensor &tensor, - const phi::KernelKey &expected_kernel_type) const { - if (var_name == "Beta1Pow" || var_name == "Beta2Pow") { - return phi::KernelKey(phi::Backend::ALL_BACKEND, - expected_kernel_type.layout(), - expected_kernel_type.dtype()); - } else { - return phi::KernelKey( - tensor.place(), tensor.layout(), expected_kernel_type.dtype()); - } - } -}; - -class LambOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("Param", - "(phi::DenseTensor, default phi::DenseTensor) " - "Input parameter that has to be updated."); - AddInput("Grad", - "(phi::DenseTensor, default phi::DenseTensor) " - "Input gradient of the parameter."); - AddInput("LearningRate", "(Tensor) Learning rate."); - AddInput("Moment1", "(Tensor) Input first moment."); - AddInput("Moment2", "(Tensor) Input second moment."); - AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator."); - AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator."); - AddInput("MasterParam", - "(phi::DenseTensor, default phi::DenseTensor) " - "Input master parameter that has to be updated.") - .AsDispensable(); - AddInput( - "SkipUpdate", - "(Tensor) Input tensor to determine whether to update the parameter.") - .AsDispensable(); - - AddOutput("ParamOut", "(Tensor) Output parameter."); - AddOutput("Moment1Out", "(Tensor) Output first moment."); - AddOutput("Moment2Out", "(Tensor) Output second moment."); - AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator") - .AsDispensable(); - AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator") - .AsDispensable(); - AddOutput("MasterParamOut", "(Tensor) Output master parameter.") - .AsDispensable(); - AddAttr("weight_decay", "(float) Weight decay rate."); - AddAttr("beta1", - "(float, default 0.9) The exponential decay rate for the " - "1st moment estimates.") - .SetDefault(0.9); - AddAttr("beta2", - "(float, default 0.999) The exponential decay rate for the " - "2nd moment estimates.") - .SetDefault(0.999); - AddAttr("epsilon", - "(float, default 1.0e-6) " - "Constant for numerical stability.") - .SetDefault(1.0e-6f); - AddAttr( - "multi_precision", - "(bool, default false) Whether to enable multi-precision mode.") - .SetDefault(false); - - 
AddComment(R"DOC( -LAMB (Layer-wise Adaptive Moments optimizer for Batching training) Optimizer. - -LAMB Optimizer is designed to scale up the batch size of training without losing -accuracy, which supports adaptive element-wise updating and accurate layer-wise -correction. For more information, please refer to https://arxiv.org/abs/1904.00962. - -The updating of parameters follows: - -$$ -m_t &= \beta_1 m_{t - 1}+ (1 - \beta_1)g_t \\ - -v_t &= \beta_2 v_{t - 1} + (1 - \beta_2)g_t^2 \\ - -m_t &= \frac{m_t}{\beta_1^t} \\ - -v_t &= \frac{v_t}{\beta_2^t} \\ - -r_t &= \frac{m_t}{\sqrt{v_t}+\epsilon} \\ - -w_t &= w_{t-1} -\eta_t \frac{\left \| w_{t-1}\right \|}{\left \| r_t + \lambda w_{t-1}\right \|} (r_t + \lambda w_{t-1}) -$$ - -where $m$ is the 1st moment, and $v$ the 2nd moment, $\eta$ the -learning rate, $\lambda$ the weight decay rate. -)DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -DECLARE_INFER_SHAPE_FUNCTOR(lamb, - LambInferMetaFunctor, - PD_INFER_META(phi::LambInferMeta)); -REGISTER_OPERATOR( - lamb, - ops::LambOp, - ops::LambOpMaker, - paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker, - LambInferMetaFunctor); - -/* ========================== register checkpoint ===========================*/ -REGISTER_OP_VERSION(lamb).AddCheckpoint( - R"ROC(Upgrade lamb, add two new outputs [Beta1PowOut] and [Beta2PowOut].)ROC", - paddle::framework::compatible::OpVersionDesc() - .NewInput("Beta1PowOut", - "The Output beta1 power accumulator. 'Beta1PowOut' is " - "dispensable.") - .NewInput("Beta2PowOut", - "The Output beta2 power accumulator. 'Beta2PowOut' is " - "dispensable.")); diff --git a/paddle/fluid/operators/optimizers/sgd_op.cc b/paddle/fluid/operators/optimizers/sgd_op.cc deleted file mode 100644 index ac445d30c31afe97aa3dad23646c789ddd6962b3..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/optimizers/sgd_op.cc +++ /dev/null @@ -1,130 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include "paddle/fluid/operators/optimizers/sgd_op.h" - -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/multiary.h" - -namespace paddle { -namespace operators { - -class SGDOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param"); - - // NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN - const auto *param_var = ctx.InputVar("Param"); - const auto *grad_var = ctx.InputVar("Grad"); - - // supported cases - bool dense_param_sparse_grad = param_var->IsType() && - grad_var->IsType(); - bool dense_param_and_grad = param_var->IsType() && - grad_var->IsType(); - if (!(dense_param_sparse_grad || dense_param_and_grad)) { - this->SetDnnFallback(true); - } - // NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN - - return phi::KernelKey(data_type, ctx.GetPlace()); - } - - phi::KernelKey GetKernelTypeForVar( - const std::string &var_name, - const phi::DenseTensor &tensor, - const phi::KernelKey &expected_kernel_type) const override { - if (var_name == "LearningRate") { - return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype()); - } - return phi::KernelKey( - tensor.place(), tensor.layout(), expected_kernel_type.dtype()); - } -}; - -class SGDOpInferVarType : public framework::VarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto in_var_type = ctx->GetInputType("Param"); - PADDLE_ENFORCE_EQ(in_var_type == framework::proto::VarType::SELECTED_ROWS || - in_var_type == framework::proto::VarType::LOD_TENSOR, - true, - platform::errors::InvalidArgument( - "The input Var's type should be LoDtensor or " - "SelectedRows, but the received type is %s", - in_var_type)); - - ctx->SetOutputType("ParamOut", in_var_type, framework::ALL_ELEMENTS); - } -}; - -class SGDOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("Param", "(Tensor or SelectedRows) Input parameter"); - AddInput("LearningRate", "(Tensor) Learning rate of SGD"); - AddInput("Grad", "(Tensor or SelectedRows) Input gradient"); - AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable(); - AddOutput("ParamOut", - "(Tensor or SelectedRows, same with Param) " - "Output parameter, should share the same memory with Param"); - AddOutput("MasterParamOut", - "The updated FP32 master weight for AMP. " - "It shared memory with Input(MasterParam).") - .AsDispensable(); - - AddAttr( - "use_mkldnn", - "(bool, default false) Indicates if MKL-DNN kernel will be used") - .SetDefault(false); - AddAttr("multi_precision", - "(bool, default false) " - "Whether to use multi-precision during weight updating.") - .SetDefault(false); - - AddComment(R"DOC( - -SGD operator - -This operator implements one step of the stochastic gradient descent algorithm. 
- -$$param\_out = param - learning\_rate * grad$$ - -)DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -DECLARE_INFER_SHAPE_FUNCTOR(sgd, - SGDInferShapeFunctor, - PD_INFER_META(phi::SgdInferMeta)); -REGISTER_OPERATOR( - sgd, - ops::SGDOp, - ops::SGDOpMaker, - paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker, - ops::SGDOpInferVarType, - SGDInferShapeFunctor); diff --git a/paddle/fluid/operators/optimizers/unity_build_rule.cmake b/paddle/fluid/operators/optimizers/unity_build_rule.cmake index 61e63ad9a6e61e682dad7fd525905eb2722a0305..05daf4cad0cf8a2900c4ebf5915dd2e9a6b34f46 100644 --- a/paddle/fluid/operators/optimizers/unity_build_rule.cmake +++ b/paddle/fluid/operators/optimizers/unity_build_rule.cmake @@ -9,7 +9,6 @@ register_unity_group( ftrl_op.cc lars_momentum_op.cc momentum_op.cc - sgd_op.cc proximal_adagrad_op.cc adagrad_op.cc adam_op.cc @@ -18,7 +17,6 @@ register_unity_group( proximal_gd_op.cc decayed_adagrad_op.cc adadelta_op.cc - lamb_op.cc dpsgd_op.cc rmsprop_op.cc) register_unity_group( diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 9f30d076a5eb6d972c336e6c2aa42d11184ae203..46a1e965e21c35659a6fe9d0eb058399203b9148 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -853,18 +853,6 @@ data_type : x backward : kldiv_loss_grad -- op : lamb_ - args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, float weight_decay, float beta1, float beta2, float epsilon, bool multi_precision) - output : Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_outs) - infer_meta : - func : LambInferMeta - kernel : - func : lamb {dense, dense, dense, dense, dense, dense, dense, dense, dense -> dense, dense, dense, dense, dense, dense}, - lamb_sr {dense, selected_rows, dense, dense, dense, dense, dense, dense, dense -> dense, dense, dense, dense, dense, dense} - data_type : param - optional : master_param, skip_update - inplace : (param -> param_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (beta1_pow -> beta1_pow_out), (beta2_pow -> beta2_pow_out), (master_param -> master_param_outs) - - op : layer_norm args : (Tensor x, Tensor scale, Tensor bias, float epsilon, int begin_norm_axis) output : Tensor(out), Tensor(mean), Tensor(variance) @@ -1455,21 +1443,6 @@ data_type : x backward : segment_pool_grad -- op : sgd_ - args : (Tensor param, Tensor learning_rate, Tensor grad, Tensor master_param, bool multi_precision) - output : Tensor(param_out), Tensor(master_param_out) - infer_meta : - func : SgdInferMeta - kernel : - func : sgd {dense, dense, dense, dense -> dense, dense}, - sgd_dense_param_sparse_grad {dense, dense, selected_rows, dense -> dense, dense}, - sgd_sparse_param_sparse_grad {selected_rows, dense, selected_rows, selected_rows -> selected_rows, selected_rows} - data_type : param - data_transform : - support_trans_dtype : learning_rate - optional : master_param - inplace : (param -> param_out), (master_param -> master_param_out) - - op : shape args : (Tensor input) output : Tensor(out) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 6e957f374941b18eb19ff72d805fa7633efee936..7409fca6980fca124da7db6e72989fc7728a3bd0 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ 
b/paddle/phi/api/yaml/op_compat.yaml @@ -984,6 +984,12 @@ outputs : out : Out +- op : lamb_ + inputs : + {param : Param, grad : Grad, learning_rate : LearningRate, moment1 : Moment1, moment2 : Moment2, beta1_pow : Beta1Pow, beta2_pow : Beta2Pow, master_param : MasterParam, skip_update : SkipUpdate} + outputs : + {param_out : ParamOut, moment1_out : Moment1Out, moment2_out : Moment2Out, beta1_pow_out : Beta1PowOut, beta2_pow_out : Beta2PowOut, master_param_outs : MasterParamOut} + - op : layer_norm backward : layer_norm_grad inputs : @@ -1578,6 +1584,16 @@ extra : attrs : [str data_format = "AnyLayout"] +- op : sgd_ + inputs : + {param : Param, learning_rate : LearningRate, grad : Grad, master_param : MasterParam} + outputs : + {param_out : ParamOut, master_param_out : MasterParamOut} + get_expected_kernel_type : + sgd : GetSgdExpectedKernelType #"sgd_" becomes "sgd" + extra : + attrs : [bool use_mkldnn=false] + - op : shape extra : attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"] diff --git a/paddle/phi/api/yaml/op_version.yaml b/paddle/phi/api/yaml/op_version.yaml index 2851b86615b74fcd2298dd4ccae3e21fad98296a..83860ce5c8f07c4e79f41589a685fcb9bf591f0e 100644 --- a/paddle/phi/api/yaml/op_version.yaml +++ b/paddle/phi/api/yaml/op_version.yaml @@ -78,6 +78,15 @@ comment : In order to specify interpolation mode default : std::string("bilinear") +- op : lamb + version : + - checkpoint : Upgrade lamb, add two new outputs [Beta1PowOut] and [Beta2PowOut]. + action : + - add_output : Beta1PowOut + comment : The Output beta1 power accumulator. 'Beta1PowOut' is dispensable. + - add_output : Beta2PowOut + comment : The Output beta2 power accumulator. 'Beta2PowOut' is dispensable. + - op : less_equal version : - checkpoint : Upgrade compare ops, add a new attribute [force_cpu] diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index e953047639f67e80277782b06bafb52c7834c928..0afba610c5cda3ac845e89226f593c7f2816b494 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -835,6 +835,18 @@ optional : prior_dist backward : label_smooth_grad +- op : lamb_ + args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, float weight_decay, float beta1=0.9, float beta2=0.999, float epsilon=1.0e-6f, bool multi_precision=false) + output : Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_outs) + infer_meta : + func : LambInferMeta + kernel : + func : lamb {dense, dense, dense, dense, dense, dense, dense, dense, dense -> dense, dense, dense, dense, dense, dense}, + lamb_sr {dense, selected_rows, dense, dense, dense, dense, dense, dense, dense -> dense, dense, dense, dense, dense, dense} + data_type : param + optional : master_param, skip_update, beta1_pow_out, beta2_pow_out, master_param_outs + inplace : (param -> param_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (beta1_pow -> beta1_pow_out), (beta2_pow -> beta2_pow_out), (master_param -> master_param_outs) + - op : leaky_relu args : (Tensor x, float negative_slope = 0.02f) output : Tensor @@ -1312,6 +1324,21 @@ data_type : x backward : send_uv_grad +- op : sgd_ + args : (Tensor param, Tensor learning_rate, Tensor grad, Tensor master_param, bool multi_precision=false) + output : Tensor(param_out), Tensor(master_param_out) + infer_meta : + func : SgdInferMeta + kernel : + func : sgd {dense, dense, dense, 
dense -> dense, dense}, + sgd_dense_param_sparse_grad {dense, dense, selected_rows, dense -> dense, dense}, + sgd_sparse_param_sparse_grad {selected_rows, dense, selected_rows, selected_rows -> selected_rows, selected_rows} + data_type : param + data_transform : + support_trans_dtype : learning_rate + optional : master_param, master_param_out + inplace : (param -> param_out), (master_param -> master_param_out) + - op : shard_index args : (Tensor input, int index_num, int nshards, int shard_id, int ignore_value=-1) output : Tensor(out) diff --git a/paddle/phi/ops/compat/lamb_sig.cc b/paddle/phi/ops/compat/lamb_sig.cc deleted file mode 100644 index a59ae6155c1832114c4e06161312d755c66909cd..0000000000000000000000000000000000000000 --- a/paddle/phi/ops/compat/lamb_sig.cc +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include - -#include "paddle/phi/core/compat/op_utils.h" -#include "paddle/utils/small_vector.h" - -namespace phi { - -KernelSignature LambOpArgumentMapping(const ArgumentMappingContext& ctx) { - paddle::small_vector in_names = {"Param", - "Grad", - "LearningRate", - "Moment1", - "Moment2", - "Beta1Pow", - "Beta2Pow", - "MasterParam", - "SkipUpdate"}; - paddle::small_vector out_names = {"ParamOut", - "Moment1Out", - "Moment2Out", - "Beta1PowOut", - "Beta2PowOut", - "MasterParamOut"}; - paddle::small_vector attr_names; - - attr_names.emplace_back("weight_decay"); - attr_names.emplace_back("beta1"); - attr_names.emplace_back("beta2"); - attr_names.emplace_back("epsilon"); - attr_names.emplace_back("multi_precision"); - - if (ctx.IsSelectedRowsInput("Grad")) { - return KernelSignature("lamb_sr", - std::move(in_names), - std::move(attr_names), - std::move(out_names)); - } else if (ctx.IsDenseTensorInput("Grad")) { - return KernelSignature("lamb", - std::move(in_names), - std::move(attr_names), - std::move(out_names)); - } else { - return KernelSignature("unregistered", {}, {}, {}); - } -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(lamb, phi::LambOpArgumentMapping); diff --git a/paddle/phi/ops/compat/sgd_sig.cc b/paddle/phi/ops/compat/sgd_sig.cc deleted file mode 100644 index cdf1a221f7ec2aa24ead046ce9e05724c4278d38..0000000000000000000000000000000000000000 --- a/paddle/phi/ops/compat/sgd_sig.cc +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature SGDOpArgumentMapping(const ArgumentMappingContext& ctx) { - if (ctx.IsDenseTensorInput("Grad")) { - return KernelSignature("sgd", - {"Param", "LearningRate", "Grad", "MasterParam"}, - {"multi_precision"}, - {"ParamOut", "MasterParamOut"}); - } else if (ctx.IsSelectedRowsInput("Grad")) { - if (ctx.IsDenseTensorInput("Param")) { - return KernelSignature("sgd_dense_param_sparse_grad", - {"Param", "LearningRate", "Grad", "MasterParam"}, - {"multi_precision"}, - {"ParamOut", "MasterParamOut"}); - } else { - return KernelSignature("sgd_sparse_param_sparse_grad", - {"Param", "LearningRate", "Grad", "MasterParam"}, - {"multi_precision"}, - {"ParamOut", "MasterParamOut"}); - } - } - - return KernelSignature("unregistered", {}, {}, {}); -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(sgd, phi::SGDOpArgumentMapping); diff --git a/test/cpp/new_executor/CMakeLists.txt b/test/cpp/new_executor/CMakeLists.txt index 577affae884f942864a583b9ab003c4cb532d373..11e4e9a84e1820a60b92e1cb3c51e03b42c6d273 100644 --- a/test/cpp/new_executor/CMakeLists.txt +++ b/test/cpp/new_executor/CMakeLists.txt @@ -30,7 +30,7 @@ if(WITH_GPU sum_op elementwise_max_op elementwise_div_op - sgd_op + generated_op squared_l2_norm_op memcpy_h2d_op memcpy_d2h_op
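
Note on the generated argument mapping: with the hand-written `sgd_sig.cc` and `lamb_sig.cc` removed, the dense/selected-rows dispatch now has to come from the new `choose_kernel_signature` macro in `operator_utils.c.j2`, driven by the `kernel : func : ... {dense, ...}` dispatch entries in `ops.yaml`. The sketch below approximates the mapping the macro should emit for the three `sgd` kernels; the function name, the `paddle::small_vector<const char*>` spelling, and the formatting are assumptions about the generated file, not a copy of it.

```cpp
// Sketch only: approximates the ArgumentMapping the updated template is
// expected to generate for sgd_, replacing the deleted sgd_sig.cc.
KernelSignature SgdOpArgumentMapping(const ArgumentMappingContext& ctx) {
  paddle::small_vector<const char*> inputs = {
      "Param", "LearningRate", "Grad", "MasterParam"};
  paddle::small_vector<const char*> attrs = {"multi_precision"};
  paddle::small_vector<const char*> outputs = {"ParamOut", "MasterParamOut"};

  // MasterParam is optional, so each branch accepts it either when it is
  // absent or when it has the type required by that kernel's dispatch entry,
  // mirroring the (HasInput && type-check) || !HasInput condition in the macro.
  if (ctx.IsDenseTensorInput("Param") &&
      ctx.IsDenseTensorInput("LearningRate") &&
      ctx.IsDenseTensorInput("Grad") &&
      ((ctx.HasInput("MasterParam") && ctx.IsDenseTensorInput("MasterParam")) ||
       !ctx.HasInput("MasterParam"))) {
    // sgd : {dense, dense, dense, dense -> dense, dense}
    return KernelSignature(
        "sgd", std::move(inputs), std::move(attrs), std::move(outputs));
  } else if (ctx.IsDenseTensorInput("Param") &&
             ctx.IsDenseTensorInput("LearningRate") &&
             ctx.IsSelectedRowsInput("Grad") &&
             ((ctx.HasInput("MasterParam") &&
               ctx.IsDenseTensorInput("MasterParam")) ||
              !ctx.HasInput("MasterParam"))) {
    // sgd_dense_param_sparse_grad : {dense, dense, selected_rows, dense -> ...}
    return KernelSignature("sgd_dense_param_sparse_grad",
                           std::move(inputs),
                           std::move(attrs),
                           std::move(outputs));
  } else if (ctx.IsSelectedRowsInput("Param") &&
             ctx.IsDenseTensorInput("LearningRate") &&
             ctx.IsSelectedRowsInput("Grad") &&
             ((ctx.HasInput("MasterParam") &&
               ctx.IsSelectedRowsInput("MasterParam")) ||
              !ctx.HasInput("MasterParam"))) {
    // sgd_sparse_param_sparse_grad : {selected_rows, dense, selected_rows,
    //                                 selected_rows -> ...}
    return KernelSignature("sgd_sparse_param_sparse_grad",
                           std::move(inputs),
                           std::move(attrs),
                           std::move(outputs));
  } else {
    return KernelSignature("unregistered", {}, {}, {});
  }
}
```

This matches the behavior of the deleted `SGDOpArgumentMapping` in `sgd_sig.cc`, with the addition of the explicit optional-input handling for `MasterParam`.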
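
Likewise, the new `get_expected_kernel_type : sgd : GetSgdExpectedKernelType` entry in `op_compat.yaml` is meant to make the generated operator forward its kernel-type selection to the helper added in `get_expected_kernel_func.cc`, preserving the MKLDNN fallback logic from the deleted `sgd_op.cc`. The class body below is an assumption about the shape of the generated code (names hypothetical), shown only to illustrate the wiring.

```cpp
// Assumed shape of the generated SgdOp: the op_compat.yaml entry should make
// the generator emit a forwarding override instead of the default
// IndicateVarDataType-based GetExpectedKernelType.
class SgdOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    // Delegates to paddle::operators::GetSgdExpectedKernelType, declared in
    // get_expected_kernel_func.h above.
    return GetSgdExpectedKernelType(ctx, this);
  }
};
```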