未验证 提交 c4204bbf 编写于 作者: W wanghuancoder 提交者: GitHub

test close trace (#55787)

* test close old dygraph trace
上级 4ddd6154
......@@ -2284,116 +2284,6 @@ void BindImperative(py::module *m_ptr) {
return std::make_tuple(
kernelsig_ins, kernelsig_attrs, kernelsig_outs);
}
})
.def("trace",
[](imperative::Tracer &self,
const std::string &type,
const PyNameVarBaseMap &ins,
const PyNameVarBaseMap &outs,
framework::AttributeMap attrs,
const platform::CustomPlace &place,
bool trace_backward,
const std::map<std::string, std::string> &inplace_map = {}) {
auto ins_map = ConvertToNameVarBaseMap(ins);
auto outs_map = ConvertToNameVarBaseMap(outs);
{
py::gil_scoped_release release;
self.TraceOp<imperative::VarBase>(type,
std::move(ins_map),
std::move(outs_map),
std::move(attrs),
place,
trace_backward,
inplace_map);
}
})
.def("trace",
[](imperative::Tracer &self,
const std::string &type,
const PyNameVarBaseMap &ins,
const PyNameVarBaseMap &outs,
framework::AttributeMap attrs,
const platform::XPUPlace &place,
bool trace_backward,
const std::map<std::string, std::string> &inplace_map = {}) {
auto ins_map = ConvertToNameVarBaseMap(ins);
auto outs_map = ConvertToNameVarBaseMap(outs);
{
py::gil_scoped_release release;
self.TraceOp<imperative::VarBase>(type,
std::move(ins_map),
std::move(outs_map),
std::move(attrs),
place,
trace_backward,
inplace_map);
}
})
.def("trace",
[](imperative::Tracer &self,
const std::string &type,
const PyNameVarBaseMap &ins,
const PyNameVarBaseMap &outs,
framework::AttributeMap attrs,
const platform::CUDAPlace &place,
bool trace_backward,
const std::map<std::string, std::string> &inplace_map = {}) {
auto ins_map = ConvertToNameVarBaseMap(ins);
auto outs_map = ConvertToNameVarBaseMap(outs);
{
py::gil_scoped_release release;
self.TraceOp<imperative::VarBase>(type,
std::move(ins_map),
std::move(outs_map),
std::move(attrs),
place,
trace_backward,
inplace_map);
}
})
.def("trace",
[](imperative::Tracer &self,
const std::string &type,
const PyNameVarBaseMap &ins,
const PyNameVarBaseMap &outs,
framework::AttributeMap attrs,
const platform::IPUPlace &place,
bool trace_backward,
const std::map<std::string, std::string> &inplace_map = {}) {
auto ins_map = ConvertToNameVarBaseMap(ins);
auto outs_map = ConvertToNameVarBaseMap(outs);
{
py::gil_scoped_release release;
self.TraceOp<imperative::VarBase>(type,
std::move(ins_map),
std::move(outs_map),
std::move(attrs),
place,
trace_backward,
inplace_map);
}
})
.def("trace",
[](imperative::Tracer &self,
const std::string &type,
const PyNameVarBaseMap &ins,
const PyNameVarBaseMap &outs,
framework::AttributeMap attrs,
const platform::CPUPlace &place,
bool trace_backward,
const std::map<std::string, std::string> &inplace_map = {}) {
auto ins_map = ConvertToNameVarBaseMap(ins);
auto outs_map = ConvertToNameVarBaseMap(outs);
{
py::gil_scoped_release release;
self.TraceOp<imperative::VarBase>(type,
std::move(ins_map),
std::move(outs_map),
std::move(attrs),
place,
trace_backward,
inplace_map);
}
});
// define parallel context
......
......@@ -324,23 +324,7 @@ class Tracer(core.Tracer):
type, inputs, outputs, attrs, stop_gradient, inplace_map
)
else:
# this block is used to convert attrs according to the opproto
# since `trace` handles AttributeMap directly, without other
# modification to the passed attribute map, so we change it before
# `trace`
if framework.OpProtoHolder.instance().has_op_proto(type):
proto = framework.OpProtoHolder.instance().get_op_proto(type)
attrs = framework.canonicalize_attrs(attrs, proto)
self.trace(
type,
inputs,
outputs,
attrs,
framework._current_expected_place(),
self._has_grad and not stop_gradient,
inplace_map if inplace_map else {},
)
raise ValueError("trace_op only work in dygraph mode")
def train_mode(self):
self._train_mode = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册