From 31909bb5e3af20aaec5f529dbe29a5a8c0d61e37 Mon Sep 17 00:00:00 2001
From: Weilong Wu
Date: Tue, 9 Aug 2022 19:04:02 +0800
Subject: [PATCH] [Eager] support final_state_full_ under eager (#44806)

* [Eager] use final_state_fill_constant_

* fill_constant use str_value

* add fill_constant_ to no_amp_list

* use float(value) as input

* support final state full_ same as fill_constant
---
 .../final_state_generator/python_c_gen.py     | 48 ++++---------------
 paddle/fluid/pybind/eager_utils.cc            |  3 ++
 .../yaml/generator/wrapped_infermeta_gen.py   |  6 +++
 paddle/phi/api/yaml/legacy_api.yaml           | 14 ++++++
 python/paddle/fluid/initializer.py            |  8 +++-
 5 files changed, 38 insertions(+), 41 deletions(-)

diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py
index f0745bd9690..8fde6951e03 100644
--- a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py
+++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py
@@ -54,45 +54,15 @@ atype_to_parsing_function = {
 # This list contains ops that do not need to generate amp logic
 # All optimizer ops in this list
 no_amp_list = [
-    'adam_',
-    'adam',
-    'adamw_',
-    'adamw',
-    'average_accumulates',
-    'average_accumulates_',
-    'decayed_adagrad_',
-    'decayed_adagrad',
-    'dgc_momentum_',
-    'dgc_momentum',
-    'distributed_fused_lamb_',
-    'distributed_fused_lamb',
-    'dpsgd_',
-    'dpsgd',
-    'ftrl_',
-    'ftrl',
-    'lamb_',
-    'lamb',
-    'lars_momentum_',
-    'lars_momentum',
-    'merged_adam_',
-    'merged_adam',
-    'merged_momentum_',
-    'merged_momentum',
-    'momentum_',
-    'momentum',
-    'proximal_adagrad_',
-    'proximal_adagrad',
-    'proximal_gd_',
-    'proximal_gd',
-    'rmsprop_',
-    'rmsprop',
-    'sgd_',
-    'sgd',
-    'lamb_',
-    'lamb',
-    'assign_value_',
-    'sparse_momentum_',
-    'sparse_momentum',
+    'adam_', 'adam', 'adamw_', 'adamw', 'average_accumulates',
+    'average_accumulates_', 'decayed_adagrad_', 'decayed_adagrad',
+    'dgc_momentum_', 'dgc_momentum', 'distributed_fused_lamb_',
+    'distributed_fused_lamb', 'dpsgd_', 'dpsgd', 'ftrl_', 'ftrl', 'lamb_',
+    'lamb', 'lars_momentum_', 'lars_momentum', 'merged_adam_', 'merged_adam',
+    'merged_momentum_', 'merged_momentum', 'momentum_', 'momentum',
+    'proximal_adagrad_', 'proximal_adagrad', 'proximal_gd_', 'proximal_gd',
+    'rmsprop_', 'rmsprop', 'sgd_', 'sgd', 'lamb_', 'lamb', 'assign_value_',
+    'sparse_momentum_', 'sparse_momentum', 'full_'
 ]
diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc
index a92ddf388c2..6c1dea40b78 100644
--- a/paddle/fluid/pybind/eager_utils.cc
+++ b/paddle/fluid/pybind/eager_utils.cc
@@ -1235,6 +1235,9 @@ paddle::experimental::Scalar CastPyArg2Scalar(PyObject* obj,
   } else if (PyObject_CheckLongOrToLong(&obj)) {
     int value = CastPyArg2Int(obj, op_type, arg_pos);
     return paddle::experimental::Scalar(value);
+  } else if (PyObject_CheckString(obj)) {
+    std::string value = CastPyArg2String(obj, op_type, arg_pos);
+    return paddle::experimental::Scalar(value);
   } else {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "%s(): argument (position %d) must be "
diff --git a/paddle/phi/api/yaml/generator/wrapped_infermeta_gen.py b/paddle/phi/api/yaml/generator/wrapped_infermeta_gen.py
index dfa6a7f93cb..0504d3fd108 100644
--- a/paddle/phi/api/yaml/generator/wrapped_infermeta_gen.py
+++ b/paddle/phi/api/yaml/generator/wrapped_infermeta_gen.py
@@ -18,6 +18,8 @@ import argparse
 
 from api_gen import ForwardAPI
 
+kernel_func_set = set()
+
 
 def get_wrapped_infermeta_name(api_name):
     return api_name.capitalize() + 'InferMeta'
@@ -29,6 +31,9 @@ def gene_wrapped_infermeta_and_register(api):
 PD_REGISTER_INFER_META_FN({api.kernel['func'][0]}, phi::{api.infer_meta['func']});"""
 
     if api.infer_meta['param'] is not None:
+        if api.kernel['func'][0] in kernel_func_set:
+            return '', '', ''
+
         kernel_params = api.kernel['param']
         if kernel_params is None:
             kernel_params = api.inputs['names'] + api.attrs['names']
@@ -78,6 +83,7 @@ void {wrapped_infermeta_name}({", ".join(args)}) {{
         register_code = f"""
 PD_REGISTER_INFER_META_FN({api.kernel['func'][0]}, phi::{get_wrapped_infermeta_name(api.kernel['func'][0])});"""
 
+        kernel_func_set.add(api.kernel['func'][0])
         return declare_code, defind_code, register_code
     else:
         return '', '', register_code
diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml
index e5a808e97b0..e90250901dc 100755
--- a/paddle/phi/api/yaml/legacy_api.yaml
+++ b/paddle/phi/api/yaml/legacy_api.yaml
@@ -975,6 +975,20 @@
     data_type : dtype
     backend : place
 
+# full
+- api : full_
+  args : (Tensor output, IntArray shape, Scalar value, DataType dtype=DataType::FLOAT32, Place place=CPUPlace())
+  output : Tensor(out)
+  inplace : (output -> out)
+  infer_meta :
+    func : CreateInferMeta
+    param : [shape, dtype]
+  kernel :
+    func : full
+    param : [shape, value, dtype]
+    data_type : dtype
+    backend : place
+
 - api : full_batch_size_like
   args : (Tensor input, int[] shape, DataType dtype, Scalar value, int input_dim_idx, int output_dim_idx, Place place=CPUPlace())
   output: Tensor
diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py
index 9174a72eed9..b4c99a7af49 100644
--- a/python/paddle/fluid/initializer.py
+++ b/python/paddle/fluid/initializer.py
@@ -138,14 +138,18 @@ class ConstantInitializer(Initializer):
                 or isinstance(var, framework.EagerParamBase))
         assert isinstance(block, framework.Block)
 
-        if framework._non_static_mode():
+        if in_dygraph_mode():
+            place = _current_expected_place()
+            _C_ops.final_state_full_(var, var.shape, str(float(self._value)),
+                                     var.dtype, place)
+            return None
+        elif _in_legacy_dygraph():
             _C_ops.fill_constant(var, 'value', float(self._value),
                                  'force_cpu', self._force_cpu, 'dtype',
                                  int(var.dtype), 'str_value',
                                  str(float(self._value)), 'shape', var.shape)
             return None
         else:
-            # fill constant should set the "str_value" to preserve precision
             op = block.append_op(type="fill_constant",
                                  outputs={"Out": var},
                                  attrs={
--
GitLab
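Two notes on the patch, with sketches that are not part of the diff itself.

First, the `kernel_func_set` guard in `wrapped_infermeta_gen.py` exists because the new `full_` API reuses the existing `full` kernel, and without deduplication the generator would emit a second `PD_REGISTER_INFER_META_FN(full, ...)` registration for the same kernel. A minimal standalone sketch of that pattern; the names `emitted` and `gen_register` are hypothetical stand-ins for the generator's internals:

```python
# Hypothetical sketch of the dedup pattern added to wrapped_infermeta_gen.py:
# emit each kernel's InferMeta registration exactly once, even when several
# APIs (e.g. full and the in-place full_) share one kernel function.
emitted = set()

def gen_register(kernel_func):
    if kernel_func in emitted:
        return ''  # skip: this kernel's registration was already generated
    emitted.add(kernel_func)
    return (f"PD_REGISTER_INFER_META_FN({kernel_func}, "
            f"phi::{kernel_func.capitalize()}InferMeta);")

print(gen_register('full'))  # prints the registration macro
print(gen_register('full'))  # prints an empty string: deduplicated
```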
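Second, the user-visible path this patch changes: in eager mode, a constant-initialized parameter is now filled in place through the generated `final_state_full_` binding, with the value passed as `str(float(value))` so the new string branch in `CastPyArg2Scalar` can carry it into a `Scalar` without precision loss. A rough usage sketch against the public API (the layer sizes and constant value below are arbitrary):

```python
import paddle

# In eager (dygraph) mode, constructing this layer runs ConstantInitializer,
# which on this branch dispatches to _C_ops.final_state_full_ instead of the
# legacy fill_constant op.
linear = paddle.nn.Linear(
    4, 4,
    weight_attr=paddle.ParamAttr(
        initializer=paddle.nn.initializer.Constant(value=2.0)))
print(linear.weight)  # 4x4 parameter, every entry 2.0
```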