未验证 提交 31909bb5 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] support final_state_full_ under eager (#44806)

* [Eager] use final_state_fill_constant_

* fill_constant use str_value

* add fill_constant_ to no_amp_list

* use float(value) as input

* support final state full_ same as fill_constant
上级 34b43555
...@@ -54,45 +54,15 @@ atype_to_parsing_function = { ...@@ -54,45 +54,15 @@ atype_to_parsing_function = {
# This list contains ops that do not need to generate amp logic # This list contains ops that do not need to generate amp logic
# All optimizer ops in this list # All optimizer ops in this list
no_amp_list = [ no_amp_list = [
'adam_', 'adam_', 'adam', 'adamw_', 'adamw', 'average_accumulates',
'adam', 'average_accumulates_', 'decayed_adagrad_', 'decayed_adagrad',
'adamw_', 'dgc_momentum_', 'dgc_momentum', 'distributed_fused_lamb_',
'adamw', 'distributed_fused_lamb', 'dpsgd_', 'dpsgd', 'ftrl_', 'ftrl', 'lamb_',
'average_accumulates', 'lamb', 'lars_momentum_', 'lars_momentum', 'merged_adam_', 'merged_adam',
'average_accumulates_', 'merged_momentum_', 'merged_momentum', 'momentum_', 'momentum',
'decayed_adagrad_', 'proximal_adagrad_', 'proximal_adagrad', 'proximal_gd_', 'proximal_gd',
'decayed_adagrad', 'rmsprop_', 'rmsprop', 'sgd_', 'sgd', 'lamb_', 'lamb', 'assign_value_',
'dgc_momentum_', 'sparse_momentum_', 'sparse_momentum', 'full_'
'dgc_momentum',
'distributed_fused_lamb_',
'distributed_fused_lamb',
'dpsgd_',
'dpsgd',
'ftrl_',
'ftrl',
'lamb_',
'lamb',
'lars_momentum_',
'lars_momentum',
'merged_adam_',
'merged_adam',
'merged_momentum_',
'merged_momentum',
'momentum_',
'momentum',
'proximal_adagrad_',
'proximal_adagrad',
'proximal_gd_',
'proximal_gd',
'rmsprop_',
'rmsprop',
'sgd_',
'sgd',
'lamb_',
'lamb',
'assign_value_',
'sparse_momentum_',
'sparse_momentum',
] ]
......
...@@ -1235,6 +1235,9 @@ paddle::experimental::Scalar CastPyArg2Scalar(PyObject* obj, ...@@ -1235,6 +1235,9 @@ paddle::experimental::Scalar CastPyArg2Scalar(PyObject* obj,
} else if (PyObject_CheckLongOrToLong(&obj)) { } else if (PyObject_CheckLongOrToLong(&obj)) {
int value = CastPyArg2Int(obj, op_type, arg_pos); int value = CastPyArg2Int(obj, op_type, arg_pos);
return paddle::experimental::Scalar(value); return paddle::experimental::Scalar(value);
} else if (PyObject_CheckString(obj)) {
std::string value = CastPyArg2String(obj, op_type, arg_pos);
return paddle::experimental::Scalar(value);
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be " "%s(): argument (position %d) must be "
......
...@@ -18,6 +18,8 @@ import argparse ...@@ -18,6 +18,8 @@ import argparse
from api_gen import ForwardAPI from api_gen import ForwardAPI
kernel_func_set = set()
def get_wrapped_infermeta_name(api_name): def get_wrapped_infermeta_name(api_name):
return api_name.capitalize() + 'InferMeta' return api_name.capitalize() + 'InferMeta'
...@@ -29,6 +31,9 @@ def gene_wrapped_infermeta_and_register(api): ...@@ -29,6 +31,9 @@ def gene_wrapped_infermeta_and_register(api):
PD_REGISTER_INFER_META_FN({api.kernel['func'][0]}, phi::{api.infer_meta['func']});""" PD_REGISTER_INFER_META_FN({api.kernel['func'][0]}, phi::{api.infer_meta['func']});"""
if api.infer_meta['param'] is not None: if api.infer_meta['param'] is not None:
if api.kernel['func'][0] in kernel_func_set:
return '', '', ''
kernel_params = api.kernel['param'] kernel_params = api.kernel['param']
if kernel_params is None: if kernel_params is None:
kernel_params = api.inputs['names'] + api.attrs['names'] kernel_params = api.inputs['names'] + api.attrs['names']
...@@ -78,6 +83,7 @@ void {wrapped_infermeta_name}({", ".join(args)}) {{ ...@@ -78,6 +83,7 @@ void {wrapped_infermeta_name}({", ".join(args)}) {{
register_code = f""" register_code = f"""
PD_REGISTER_INFER_META_FN({api.kernel['func'][0]}, phi::{get_wrapped_infermeta_name(api.kernel['func'][0])});""" PD_REGISTER_INFER_META_FN({api.kernel['func'][0]}, phi::{get_wrapped_infermeta_name(api.kernel['func'][0])});"""
kernel_func_set.add(api.kernel['func'][0])
return declare_code, defind_code, register_code return declare_code, defind_code, register_code
else: else:
return '', '', register_code return '', '', register_code
......
...@@ -975,6 +975,20 @@ ...@@ -975,6 +975,20 @@
data_type : dtype data_type : dtype
backend : place backend : place
# full
- api : full_
args : (Tensor output, IntArray shape, Scalar value, DataType dtype=DataType::FLOAT32, Place place=CPUPlace())
output : Tensor(out)
inplace : (output -> out)
infer_meta :
func : CreateInferMeta
param : [shape, dtype]
kernel :
func : full
param : [shape, value, dtype]
data_type : dtype
backend : place
- api : full_batch_size_like - api : full_batch_size_like
args : (Tensor input, int[] shape, DataType dtype, Scalar value, int input_dim_idx, int output_dim_idx, Place place=CPUPlace()) args : (Tensor input, int[] shape, DataType dtype, Scalar value, int input_dim_idx, int output_dim_idx, Place place=CPUPlace())
output: Tensor output: Tensor
......
...@@ -138,14 +138,18 @@ class ConstantInitializer(Initializer): ...@@ -138,14 +138,18 @@ class ConstantInitializer(Initializer):
or isinstance(var, framework.EagerParamBase)) or isinstance(var, framework.EagerParamBase))
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
if framework._non_static_mode(): if in_dygraph_mode():
place = _current_expected_place()
_C_ops.final_state_full_(var, var.shape, str(float(self._value)),
var.dtype, place)
return None
elif _in_legacy_dygraph():
_C_ops.fill_constant(var, 'value', float(self._value), _C_ops.fill_constant(var, 'value', float(self._value),
'force_cpu', self._force_cpu, 'dtype', 'force_cpu', self._force_cpu, 'dtype',
int(var.dtype), 'str_value', int(var.dtype), 'str_value',
str(float(self._value)), 'shape', var.shape) str(float(self._value)), 'shape', var.shape)
return None return None
else: else:
# fill constant should set the "str_value" to preserve precision
op = block.append_op(type="fill_constant", op = block.append_op(type="fill_constant",
outputs={"Out": var}, outputs={"Out": var},
attrs={ attrs={
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册