未验证 提交 d13a4a25 编写于 作者: J Jiabin Yang 提交者: GitHub

[Eager] Fix ocr (#46124)

* fix linspace error in amp

* fix log

* fix amp error

* fix ocr error which caused by amp

* add more check

* rename dtype ns
上级 56f9452c
......@@ -1797,6 +1797,15 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
generated_function_body += amp_context;
generated_function_body += "\n";
}
if (!forward_inplace_map.empty()) {
generated_function_body +=
" auto current_level = egr::Controller::Instance().GetAMPLevel();\n";
generated_function_body +=
" "
"egr::Controller::Instance().SetAMPLevel(paddle::imperative::AmpLevel::"
"O0);\n";
}
// forward ins insert
const char* FWD_INS_MAP_TEMPLATE =
" std::map<std::string, "
......@@ -1999,6 +2008,10 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
}
trace_op_body_str += out_tensor_str;
}
if (!forward_inplace_map.empty()) {
trace_op_body_str +=
" egr::Controller::Instance().SetAMPLevel(current_level);\n";
}
trace_op_body_str += "\n";
VLOG(6) << "Converted Output VarBase to EagerVariable(s)";
/* ------ END Generate TraceOp ----- */
......
......@@ -47,7 +47,9 @@ typedef SSIZE_T ssize_t;
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
#include "paddle/fluid/eager/amp_utils.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/eager/eager_amp_auto_cast.h"
#include "paddle/fluid/framework/python_headers.h"
#include "paddle/fluid/memory/allocation/mmap_allocator.h"
#include "paddle/fluid/pybind/tensor_py.h"
......@@ -1171,6 +1173,17 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
// Release gil and do tracing
py::gil_scoped_release release;
// use inplace set_value_ operator
if (value_tensor.initialized() &&
(self->tensor.dtype() != value_tensor.dtype())) {
paddle::small_vector<std::vector<paddle::experimental::Tensor>,
egr::kSlotSmallVectorSize>
tmps = {{self->tensor}, {value_tensor}};
auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps);
self->tensor = egr::EagerAmpAutoCast(
self->tensor.name(), self->tensor, amp_dtype, "set_value");
value_tensor = egr::EagerAmpAutoCast(
value_tensor.name(), value_tensor, amp_dtype, "set_value");
}
self->tensor = set_value__dygraph_function(
self->tensor, value_tensor, {}, {}, {}, attrs);
}
......
......@@ -26,4 +26,5 @@ PD_REGISTER_KERNEL(set_value_grad,
double,
int,
int64_t,
bool) {}
bool,
phi::dtype::float16) {}
......@@ -26,7 +26,8 @@ PD_REGISTER_KERNEL(set_value,
double,
int,
int64_t,
bool) {}
bool,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(set_value_with_tensor,
CPU,
ALL_LAYOUT,
......@@ -35,4 +36,5 @@ PD_REGISTER_KERNEL(set_value_with_tensor,
double,
int,
int64_t,
bool) {}
bool,
phi::dtype::float16) {}
......@@ -26,4 +26,5 @@ PD_REGISTER_KERNEL(set_value_grad,
double,
int,
int64_t,
bool) {}
bool,
phi::dtype::float16) {}
......@@ -26,7 +26,8 @@ PD_REGISTER_KERNEL(set_value,
double,
int,
int64_t,
bool) {}
bool,
paddle::platform::float16) {}
PD_REGISTER_KERNEL(set_value_with_tensor,
GPU,
ALL_LAYOUT,
......@@ -35,4 +36,5 @@ PD_REGISTER_KERNEL(set_value_with_tensor,
double,
int,
int64_t,
bool) {}
bool,
paddle::platform::float16) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册