未验证 提交 d13a4a25 编写于 作者: J Jiabin Yang 提交者: GitHub

[Eager] Fix ocr (#46124)

* fix linspace error in amp

* fix log

* fix amp error

* fix ocr error which caused by amp

* add more check

* rename dtype ns
上级 56f9452c
...@@ -1797,6 +1797,15 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -1797,6 +1797,15 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
generated_function_body += amp_context; generated_function_body += amp_context;
generated_function_body += "\n"; generated_function_body += "\n";
} }
if (!forward_inplace_map.empty()) {
generated_function_body +=
" auto current_level = egr::Controller::Instance().GetAMPLevel();\n";
generated_function_body +=
" "
"egr::Controller::Instance().SetAMPLevel(paddle::imperative::AmpLevel::"
"O0);\n";
}
// forward ins insert // forward ins insert
const char* FWD_INS_MAP_TEMPLATE = const char* FWD_INS_MAP_TEMPLATE =
" std::map<std::string, " " std::map<std::string, "
...@@ -1999,6 +2008,10 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -1999,6 +2008,10 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
} }
trace_op_body_str += out_tensor_str; trace_op_body_str += out_tensor_str;
} }
if (!forward_inplace_map.empty()) {
trace_op_body_str +=
" egr::Controller::Instance().SetAMPLevel(current_level);\n";
}
trace_op_body_str += "\n"; trace_op_body_str += "\n";
VLOG(6) << "Converted Output VarBase to EagerVariable(s)"; VLOG(6) << "Converted Output VarBase to EagerVariable(s)";
/* ------ END Generate TraceOp ----- */ /* ------ END Generate TraceOp ----- */
......
...@@ -47,7 +47,9 @@ typedef SSIZE_T ssize_t; ...@@ -47,7 +47,9 @@ typedef SSIZE_T ssize_t;
#include "pybind11/numpy.h" #include "pybind11/numpy.h"
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
#pragma GCC diagnostic ignored "-Wmissing-field-initializers" #pragma GCC diagnostic ignored "-Wmissing-field-initializers"
#include "paddle/fluid/eager/amp_utils.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/eager/eager_amp_auto_cast.h"
#include "paddle/fluid/framework/python_headers.h" #include "paddle/fluid/framework/python_headers.h"
#include "paddle/fluid/memory/allocation/mmap_allocator.h" #include "paddle/fluid/memory/allocation/mmap_allocator.h"
#include "paddle/fluid/pybind/tensor_py.h" #include "paddle/fluid/pybind/tensor_py.h"
...@@ -1171,6 +1173,17 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self, ...@@ -1171,6 +1173,17 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
// Release gil and do tracing // Release gil and do tracing
py::gil_scoped_release release; py::gil_scoped_release release;
// use inplace set_value_ operator // use inplace set_value_ operator
if (value_tensor.initialized() &&
(self->tensor.dtype() != value_tensor.dtype())) {
paddle::small_vector<std::vector<paddle::experimental::Tensor>,
egr::kSlotSmallVectorSize>
tmps = {{self->tensor}, {value_tensor}};
auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps);
self->tensor = egr::EagerAmpAutoCast(
self->tensor.name(), self->tensor, amp_dtype, "set_value");
value_tensor = egr::EagerAmpAutoCast(
value_tensor.name(), value_tensor, amp_dtype, "set_value");
}
self->tensor = set_value__dygraph_function( self->tensor = set_value__dygraph_function(
self->tensor, value_tensor, {}, {}, {}, attrs); self->tensor, value_tensor, {}, {}, {}, attrs);
} }
......
...@@ -26,4 +26,5 @@ PD_REGISTER_KERNEL(set_value_grad, ...@@ -26,4 +26,5 @@ PD_REGISTER_KERNEL(set_value_grad,
double, double,
int, int,
int64_t, int64_t,
bool) {} bool,
phi::dtype::float16) {}
...@@ -26,7 +26,8 @@ PD_REGISTER_KERNEL(set_value, ...@@ -26,7 +26,8 @@ PD_REGISTER_KERNEL(set_value,
double, double,
int, int,
int64_t, int64_t,
bool) {} bool,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(set_value_with_tensor, PD_REGISTER_KERNEL(set_value_with_tensor,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -35,4 +36,5 @@ PD_REGISTER_KERNEL(set_value_with_tensor, ...@@ -35,4 +36,5 @@ PD_REGISTER_KERNEL(set_value_with_tensor,
double, double,
int, int,
int64_t, int64_t,
bool) {} bool,
phi::dtype::float16) {}
...@@ -26,4 +26,5 @@ PD_REGISTER_KERNEL(set_value_grad, ...@@ -26,4 +26,5 @@ PD_REGISTER_KERNEL(set_value_grad,
double, double,
int, int,
int64_t, int64_t,
bool) {} bool,
phi::dtype::float16) {}
...@@ -26,7 +26,8 @@ PD_REGISTER_KERNEL(set_value, ...@@ -26,7 +26,8 @@ PD_REGISTER_KERNEL(set_value,
double, double,
int, int,
int64_t, int64_t,
bool) {} bool,
paddle::platform::float16) {}
PD_REGISTER_KERNEL(set_value_with_tensor, PD_REGISTER_KERNEL(set_value_with_tensor,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -35,4 +36,5 @@ PD_REGISTER_KERNEL(set_value_with_tensor, ...@@ -35,4 +36,5 @@ PD_REGISTER_KERNEL(set_value_with_tensor,
double, double,
int, int,
int64_t, int64_t,
bool) {} bool,
paddle::platform::float16) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册