未验证 提交 b96c7c9a 编写于 作者: L Leo Chen 提交者: GitHub

polish code, test=develop (#22380)

remove unnecessary template.
上级 a392b777
......@@ -87,10 +87,7 @@ class Tracer {
platform::Place ExpectedPlace() const { return expected_place_; }
template <typename PlaceType>
void SetExpectedPlace(PlaceType place) {
expected_place_ = place;
}
void SetExpectedPlace(platform::Place place) { expected_place_ = place; }
bool NoGrad() const { return no_grad_; }
......
......@@ -536,18 +536,18 @@ void BindImperative(py::module *m_ptr) {
[](imperative::Tracer &self, const py::object &obj) {
if (py::isinstance<platform::CUDAPlace>(obj)) {
auto p = obj.cast<platform::CUDAPlace *>();
self.SetExpectedPlace<platform::CUDAPlace>(*p);
self.SetExpectedPlace(*p);
} else if (py::isinstance<platform::CPUPlace>(obj)) {
auto p = obj.cast<platform::CPUPlace *>();
self.SetExpectedPlace<platform::CPUPlace>(*p);
self.SetExpectedPlace(*p);
} else if (py::isinstance<platform::CUDAPinnedPlace>(obj)) {
auto p = obj.cast<platform::CUDAPinnedPlace *>();
self.SetExpectedPlace<platform::CUDAPinnedPlace>(*p);
self.SetExpectedPlace(*p);
} else {
PADDLE_THROW(
PADDLE_THROW(platform::errors::InvalidArgument(
"Incompatible Place Type: supports CUDAPlace, CPUPlace, "
"CUDAPinnedPlace, "
"but got Unknown Type!");
"and CUDAPinnedPlace, "
"but got Unknown Type!"));
}
})
.def("_get_program_desc_tracer",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册