Unverified · Commit 0bcfc474 · authored by H hong, committed by GitHub

fix eager gen opti bug (#41302)

* fix eager gen opti bug

* polish code

* fix some bugs

* fix some bugs
Parent 49e4e2f9
@@ -359,6 +359,12 @@ CREATE_PLAIN_OPTIONAL_TENSOR_TEMPLATE = \
   if({}.initialized()) {}_optional = paddle::make_optional<const paddle::experimental::Tensor&>({});
 """

+CREATE_RECOVER_OPTIONAL_TENSOR_TEMPLATE = \
+"""
+    paddle::optional<const paddle::experimental::Tensor&> {}_optional = paddle::none;
+    if( {}.impl() ) {}_optional = paddle::make_optional<const paddle::experimental::Tensor&>({});
+"""
+
 #######################
 ## Generator Helpers ##
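Note: all four `{}` slots of this template are filled with the same recovered-tensor name. For a wrapped input named `x` (a placeholder, not a name from this diff), the template would expand to roughly:

```cpp
// Sketch of the C++ this template expands to for an input named "x"
// ("x" is illustrative). The presence check is x.impl(), not
// x.initialized(): a tensor recovered from an empty TensorWrapper
// carries a null impl_, which the tensor.h change below makes explicit.
paddle::optional<const paddle::experimental::Tensor&> x_optional = paddle::none;
if (x.impl()) {
  x_optional = paddle::make_optional<const paddle::experimental::Tensor&>(x);
}
```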
@@ -1248,11 +1254,18 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
                 name)
             is_optional = (name in self.optional_inputs)
+            tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, this->shared_from_this());"
             if is_optional:
-                tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverOptionalTensorWrapper(&this->{tensor_wrapper_name}, this->shared_from_this());"
+                tensor_wrapper_recover_str += "\n" + CREATE_RECOVER_OPTIONAL_TENSOR_TEMPLATE.format(
+                    transformed_tensor_name, transformed_tensor_name,
+                    transformed_tensor_name, transformed_tensor_name)
+
+                grad_api_args[
+                    grad_api_position] = transformed_tensor_name + "_optional"
             else:
-                tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, this->shared_from_this());"
-            grad_api_args[grad_api_position] = transformed_tensor_name
+                grad_api_args[grad_api_position] = transformed_tensor_name
+
             get_grad_in_args_list.append(tensor_wrapper_recover_str)

             # Grad Ins from grads
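Note: combined with the template above, the generator now always emits a plain `RecoverTensorWrapper` call and wraps the result locally, and the grad API receives the `_optional` variable at `grad_api_position` instead of the raw tensor. A hedged sketch of the emitted code, with illustrative identifier names:

```cpp
// Illustrative generator output for an optional wrapped input "x"
// (identifier names are assumptions, not taken from the diff).
auto x = egr::EagerUtils::RecoverTensorWrapper(&this->x_,
                                               this->shared_from_this());
paddle::optional<const paddle::experimental::Tensor&> x_optional = paddle::none;
if (x.impl()) {
  x_optional = paddle::make_optional<const paddle::experimental::Tensor&>(x);
}
// grad_api(..., x_optional, ...);  // passed at grad_api_position
```

Because `x` is a local in the same generated scope, the reference held by `x_optional` stays valid for the grad call.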
@@ -364,22 +364,6 @@ paddle::experimental::Tensor EagerUtils::RecoverTensorWrapper(
   return tw->recover(grad_node);
 }

-paddle::optional<const paddle::experimental::Tensor&>
-EagerUtils::RecoverOptionalTensorWrapper(
-    TensorWrapper* tw, const std::shared_ptr<GradNodeBase>& grad_node) {
-  PADDLE_ENFORCE_NOT_NULL(
-      tw, phi::errors::InvalidArgument("TensorWrapper in "
-                                       "RecoverOptionalTensorWrapper function "
-                                       "should not be null"));
-  auto tmp = tw->recover(grad_node);
-
-  paddle::optional<const paddle::experimental::Tensor&> res{paddle::none};
-  if (tmp.initialized()) {
-    res = tmp;
-  }
-  return res;
-}
-
 std::vector<paddle::experimental::Tensor> EagerUtils::RecoverTensorWrapper(
     std::vector<TensorWrapper>* tw,
     const std::shared_ptr<GradNodeBase>& grad_node) {
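Note: the removed helper bound a `paddle::optional<const Tensor&>` to the function-local `tmp` and returned it, so callers could end up holding a reference to an already-destroyed tensor; that dangling-reference hazard is presumably the "opti" (optional) bug named in the title. A minimal standalone illustration of the same shape (not Paddle code):

```cpp
#include <functional>
#include <optional>
#include <string>

// Same shape as the deleted RecoverOptionalTensorWrapper: an
// optional-of-reference bound to a local escapes the function.
std::optional<std::reference_wrapper<const std::string>> make_ref() {
  std::string local = "recovered";  // destroyed when make_ref returns
  return std::cref(local);          // dangling: still refers to `local`
}
```

Moving the optional construction into the generated caller's scope, as above, avoids this.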
@@ -179,9 +179,6 @@ class EagerUtils {
   static std::vector<paddle::experimental::Tensor> RecoverTensorWrapper(
       std::vector<TensorWrapper>* tw,
       const std::shared_ptr<GradNodeBase>& grad_node);
-  static paddle::optional<const paddle::experimental::Tensor&>
-  RecoverOptionalTensorWrapper(TensorWrapper* tw,
-                               const std::shared_ptr<GradNodeBase>& grad_node);

   // Intermidate needed remove this once we don't need legacy
   // Inner Method
@@ -971,7 +971,7 @@ paddle::experimental::IntArray CastPyArg2IntArray(PyObject* obj,
     std::vector<int> value = CastPyArg2Ints(obj, op_type, arg_pos);
     return paddle::experimental::IntArray(value);

-  } else if (type_name == "paddle.Tensor") {
+  } else if (type_name == "paddle.Tensor" || type_name == "Tensor") {
     paddle::experimental::Tensor& value = GetTensorFromPyObject(
         op_type, "" /*arg_name*/, obj, arg_pos, false /*dispensable*/);
     return paddle::experimental::IntArray(value);
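Note: this likely accounts for the eager-mode tensor class reporting its Python type name as plain `Tensor` rather than `paddle.Tensor`; both spellings now take the tensor-to-IntArray branch. A hedged sketch of the dispatch:

```cpp
// Hedged sketch: both type-name spellings reach the same branch.
// The literal is an assumption for clarity; in the real code type_name
// comes from the Python object's type.
std::string type_name = "Tensor";  // eager Tensor; legacy path sees "paddle.Tensor"
bool is_tensor = (type_name == "paddle.Tensor" || type_name == "Tensor");
// if is_tensor: read the tensor's values into paddle::experimental::IntArray
```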
@@ -567,7 +567,7 @@ class PADDLE_API Tensor final {
    * heterogeneous Tensor implementation, so that the API level can be unified
    * to one `Tensor`.
    */
-  std::shared_ptr<phi::TensorBase> impl_;
+  std::shared_ptr<phi::TensorBase> impl_{nullptr};

   /**
    * [ Why need abstract AbstractAutogradMeta here? ]
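Note: a default-constructed `std::shared_ptr` is already null, so `impl_{nullptr}` changes no behavior; it spells out the invariant that the generated `if (x.impl())` checks above depend on. A small sketch of the assumed usage:

```cpp
// Assumed usage: an empty Tensor exposes a null impl until data is set.
paddle::experimental::Tensor t;                // impl_ is nullptr
bool holds_value = static_cast<bool>(t.impl()); // false for an empty tensor
```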
@@ -66,14 +66,6 @@ phi::MetaTensor MakeMetaTensor(const phi::DenseTensor& tensor) {
   return phi::MetaTensor(tensor);
 }

-paddle::optional<phi::MetaTensor> MakeMetaTensor(
-    const paddle::optional<const phi::DenseTensor&>& tensor) {
-  if (tensor) {
-    return {phi::MetaTensor(*tensor)};
-  }
-  return {paddle::none};
-}
-
 std::vector<phi::MetaTensor> MakeMetaTensor(
     const std::vector<const phi::DenseTensor*>& tensors) {
   std::vector<phi::MetaTensor> meta_tensors;
@@ -88,14 +80,6 @@ phi::MetaTensor MakeMetaTensor(const phi::SelectedRows& tensor) {
   return phi::MetaTensor(tensor);
 }

-paddle::optional<phi::MetaTensor> MakeMetaTensor(
-    const paddle::optional<const phi::SelectedRows&>& tensor) {
-  if (tensor) {
-    return {phi::MetaTensor(*tensor)};
-  }
-  return {paddle::none};
-}
-
 phi::MetaTensor MakeMetaTensor(const phi::StringTensor& tensor) {
   return phi::MetaTensor(tensor);
 }
@@ -50,17 +50,11 @@ std::shared_ptr<phi::StringTensor> TensorToStringTensor(const Tensor& tensor);

 phi::MetaTensor MakeMetaTensor(const phi::DenseTensor& tensor);

-paddle::optional<phi::MetaTensor> MakeMetaTensor(
-    const paddle::optional<const phi::DenseTensor&>& tensor);
-
 std::vector<phi::MetaTensor> MakeMetaTensor(
     const std::vector<const phi::DenseTensor*>& tensors);

 phi::MetaTensor MakeMetaTensor(const phi::SelectedRows& tensor);

-paddle::optional<phi::MetaTensor> MakeMetaTensor(
-    const paddle::optional<const phi::SelectedRows&>& tensor);
-
 phi::MetaTensor MakeMetaTensor(const phi::StringTensor& tensor);

 /* ------------------ for output ----------------------- */
@@ -480,11 +480,15 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self
                 param_code = param_code + param + "_metas, "
             elif param in self.optional_vars:
                 meta_tensor_code = meta_tensor_code + f"""
-{code_indent}  paddle::optional<const phi::MetaTensor&> {PREFIX_TENSOR_NAME}meta_ref_{param}(paddle::none);
-{code_indent}  auto {PREFIX_TENSOR_NAME}meta_{param} = MakeMetaTensor({PREFIX_TENSOR_NAME}{param});
-{code_indent}  if ({PREFIX_TENSOR_NAME}meta_{param}) {{
-{code_indent}    {PREFIX_TENSOR_NAME}meta_ref_{param} = paddle::make_optional<const phi::MetaTensor&>(*{PREFIX_TENSOR_NAME}meta_{param});
-{code_indent}  }}"""
+{code_indent}  paddle::optional<const phi::MetaTensor&> {PREFIX_TENSOR_NAME}meta_ref_{param} = paddle::none;
+{code_indent}  phi::DenseTensor dt;
+{code_indent}  phi::MetaTensor {PREFIX_TENSOR_NAME}meta_tmp_{param}(dt);
+{code_indent}  if ({PREFIX_TENSOR_NAME}{param}_ptr) {{
+{code_indent}    {PREFIX_TENSOR_NAME}meta_tmp_{param}.set_dtype( {PREFIX_TENSOR_NAME}{param}_ptr->dtype() );
+{code_indent}    {PREFIX_TENSOR_NAME}meta_tmp_{param}.set_dims( {PREFIX_TENSOR_NAME}{param}_ptr->dims() );
+{code_indent}    {PREFIX_TENSOR_NAME}meta_tmp_{param}.set_layout( {PREFIX_TENSOR_NAME}{param}_ptr->layout() );
+{code_indent}    {PREFIX_TENSOR_NAME}meta_ref_{param} = {PREFIX_TENSOR_NAME}meta_tmp_{param};
+{code_indent}  }}\n"""

                 param_code = param_code + f"{PREFIX_TENSOR_NAME}meta_ref_{param}, "
             else:
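Note: instead of the removed `MakeMetaTensor` optional overloads, the generated code now materializes a `MetaTensor` backed by a placeholder `DenseTensor` and copies dtype/dims/layout from the optional input's pointer. Assuming `PREFIX_TENSOR_NAME` expands to `input_` (an assumption for illustration), the emitted code for an optional parameter `x` would look roughly like:

```cpp
// Rough expansion of the template above for an optional parameter "x",
// assuming PREFIX_TENSOR_NAME == "input_" (illustrative names).
paddle::optional<const phi::MetaTensor&> input_meta_ref_x = paddle::none;
phi::DenseTensor dt;                   // placeholder storage
phi::MetaTensor input_meta_tmp_x(dt);  // meta view over the placeholder
if (input_x_ptr) {
  input_meta_tmp_x.set_dtype(input_x_ptr->dtype());
  input_meta_tmp_x.set_dims(input_x_ptr->dims());
  input_meta_tmp_x.set_layout(input_x_ptr->layout());
  input_meta_ref_x = input_meta_tmp_x;  // reference stays valid: same scope
}
// infer_meta(..., input_meta_ref_x, ...);
```

Since `input_meta_tmp_x` lives in the calling scope, the optional reference it backs cannot dangle.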