Unverified commit 01bfe786, authored by WangZhen, committed by GitHub

Get three grad lists in C++ to avoid GPU idle time (#47665)

* Get three grad lists in C++ to avoid GPU idle time

* Support legacy mode
Parent 0b3b4918
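In short, the commit moves the dtype-based splitting of gradients out of Python list comprehensions and into a single pybind call, so the GPU is not left idle while Python iterates the parameter list. A minimal sketch of the call pattern this enables, assuming an eager-mode build where the binding added below is exposed as core.eager.get_grads_lists (the toy model and loss are illustrative, not part of the commit):

import paddle
from paddle.fluid import core

model = paddle.nn.Linear(4, 4)
loss = model(paddle.rand([2, 4])).sum()
loss.backward()

# One C++ pass buckets every initialized gradient by dtype, replacing the
# per-dtype Python comprehensions removed from AmpScaler in this commit.
fp16_grads, bf16_grads, fp32_grads = core.eager.get_grads_lists(
    model.parameters()
)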
@@ -190,6 +190,41 @@ static PyObject* eager_api_tensor_copy(PyObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
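
// Bucket the initialized gradients of the given tensors into three lists by
// dtype (FP16, BF16, FP32) in a single C++ pass, so the Python caller does
// not have to iterate the parameter list once per dtype.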
static PyObject* eager_api_get_grads_lists(PyObject* self,
                                           PyObject* args,
                                           PyObject* kwargs) {
  EAGER_TRY
  auto tensor_list = CastPyArg2VectorOfTensor(PyTuple_GET_ITEM(args, 0), 0);
  std::vector<std::vector<paddle::experimental::Tensor>> ret(3);
  for (auto& tensor : tensor_list) {
    VLOG(6) << "Get grad for tensor: " << tensor.name();
    auto meta = egr::EagerUtils::nullable_autograd_meta(tensor);
    // nullable_autograd_meta may return nullptr; check it before any
    // dereference.
    if (meta && meta->Grad().initialized()) {
      VLOG(6) << meta << " initialized: " << meta->Grad().initialized();
      auto& grad = meta->Grad();
      switch (grad.dtype()) {
        case paddle::experimental::DataType::FLOAT16:
          ret[0].emplace_back(grad);
          break;
        case paddle::experimental::DataType::BFLOAT16:
          ret[1].emplace_back(grad);
          break;
        case paddle::experimental::DataType::FLOAT32:
          ret[2].emplace_back(grad);
          break;
        default:
          break;
      }
    }
  }
  return ToPyObject(ret);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

static PyObject* eager_api_read_next_tensor_list(PyObject* self,
                                                 PyObject* args,
                                                 PyObject* kwargs) {

@@ -1001,6 +1036,10 @@ PyMethodDef variable_functions[] = {
     (PyCFunction)(void (*)(void))eager_api_tensor_copy,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"get_grads_lists",
     (PyCFunction)(void (*)(void))eager_api_get_grads_lists,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"read_next_tensor_list",
     (PyCFunction)(void (*)(void))eager_api_read_next_tensor_list,
     METH_VARARGS | METH_KEYWORDS,
......
@@ -720,6 +720,17 @@ PyObject* ToPyObject(const std::vector<paddle::experimental::Tensor>& value,
  return result;
}

PyObject* ToPyObject(
    const std::vector<std::vector<paddle::experimental::Tensor>>& value) {
  PyObject* result = PyList_New((Py_ssize_t)value.size());
  for (size_t i = 0; i < value.size(); i++) {
    PyList_SET_ITEM(result, static_cast<Py_ssize_t>(i), ToPyObject(value[i]));
  }
  return result;
}

PyObject* ToPyObject(const platform::Place& value) {
  auto obj = ::pybind11::cast(value);
  obj.inc_ref();
......
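The overload above (presumably in paddle/fluid/pybind/eager_utils.cc, with its declaration in the header hunk below) turns the nested vector into a Python list of three tensor lists, which is what allows the tuple unpacking used by the AmpScaler change further down. A small hedged illustration, where params stands in for any list of paddle Tensors with populated grads:

ret = core.eager.get_grads_lists(params)
assert isinstance(ret, list) and len(ret) == 3  # [fp16, bf16, fp32] buckets
fp16_grads, bf16_grads, fp32_grads = ret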
@@ -103,6 +103,8 @@ PyObject* ToPyObject(const std::vector<double>& value);
PyObject* ToPyObject(const std::vector<std::vector<size_t>>& value);
PyObject* ToPyObject(const std::vector<paddle::experimental::Tensor>& value,
                     bool return_py_none_if_not_initialize = false);
PyObject* ToPyObject(
    const std::vector<std::vector<paddle::experimental::Tensor>>& value);
PyObject* ToPyObject(const platform::Place& value);
PyObject* ToPyObject(const phi::DenseTensor* value);
PyObject* ToPyObject(const phi::SelectedRows* value);
......
@@ -26,6 +26,7 @@ import numpy as np
from paddle import _C_ops, _legacy_C_ops
from collections import defaultdict
from enum import Enum
from paddle.fluid import in_dygraph_mode

__all__ = ['AmpScaler', 'OptimizerState']
@@ -297,26 +298,33 @@ class AmpScaler(object):
                        else:
                            param_grads_fp32.append(param._grad_ivar())
        else:
-            param_grads = [
-                param._grad_ivar()
-                for param in optimizer._parameter_list
-                if param._grad_ivar() is not None
-            ]
-            param_grads_fp16 = [
-                param
-                for param in param_grads
-                if param.dtype == core.VarDesc.VarType.FP16
-            ]
-            param_grads_bf16 = [
-                param
-                for param in param_grads
-                if param.dtype == core.VarDesc.VarType.BF16
-            ]
-            param_grads_fp32 = [
-                param
-                for param in param_grads
-                if param.dtype == core.VarDesc.VarType.FP32
-            ]
+            if in_dygraph_mode():
+                (
+                    param_grads_fp16,
+                    param_grads_bf16,
+                    param_grads_fp32,
+                ) = core.eager.get_grads_lists(optimizer._parameter_list)
+            else:
+                param_grads = [
+                    param._grad_ivar()
+                    for param in optimizer._parameter_list
+                    if param._grad_ivar() is not None
+                ]
+                param_grads_fp16 = [
+                    param
+                    for param in param_grads
+                    if param.dtype == core.VarDesc.VarType.FP16
+                ]
+                param_grads_bf16 = [
+                    param
+                    for param in param_grads
+                    if param.dtype == core.VarDesc.VarType.BF16
+                ]
+                param_grads_fp32 = [
+                    param
+                    for param in param_grads
+                    if param.dtype == core.VarDesc.VarType.FP32
+                ]
        if core.is_compiled_with_npu():
            float_status = _legacy_C_ops.alloc_float_status()
            _legacy_C_ops.clear_float_status(float_status, float_status)
......
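For reference, the eager fast path added above and the legacy comprehensions it wraps are meant to produce the same three lists. A Python rendering of the C++ bucketing semantics, purely illustrative (the helper name is hypothetical; it reuses the legacy _grad_ivar() accessor seen in this file):

from paddle.fluid import core

def get_grads_lists_reference(params):
    # Mirrors core.eager.get_grads_lists: keep only materialized grads whose
    # dtype is FP16, BF16, or FP32, preserving parameter order per bucket.
    fp16, bf16, fp32 = [], [], []
    for p in params:
        g = p._grad_ivar()
        if g is None:
            continue
        if g.dtype == core.VarDesc.VarType.FP16:
            fp16.append(g)
        elif g.dtype == core.VarDesc.VarType.BF16:
            bf16.append(g)
        elif g.dtype == core.VarDesc.VarType.FP32:
            fp32.append(g)
    return fp16, bf16, fp32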