未验证 提交 4d886f75 编写于 作者: X xiongkun 提交者: GitHub

run python api in eager mode and filter the out in argument list (#40523)

* run python api in eager mode and filter the out in argument list

* fix code
上级 30417999
...@@ -390,8 +390,8 @@ bool Tracer::ComputeRequiredGrad(const NameTensorMap& ins, ...@@ -390,8 +390,8 @@ bool Tracer::ComputeRequiredGrad(const NameTensorMap& ins,
} }
phi::KernelSignature Tracer::GetExpectedKernelSignature( phi::KernelSignature Tracer::GetExpectedKernelSignature(
const std::string& type, const NameVarBaseMap& ins, const std::string& type, const NameTensorMap& ins,
const NameVarBaseMap& outs, framework::AttributeMap attrs) const { const NameTensorMap& outs, framework::AttributeMap attrs) const {
auto op = framework::OpRegistry::CreateOp(type, {}, {}, {}, false); auto op = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
framework::RuntimeContext ctx({}, {}); framework::RuntimeContext ctx({}, {});
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
...@@ -406,7 +406,7 @@ phi::KernelSignature Tracer::GetExpectedKernelSignature( ...@@ -406,7 +406,7 @@ phi::KernelSignature Tracer::GetExpectedKernelSignature(
attr_checker == nullptr ? empty_attrs_map attr_checker == nullptr ? empty_attrs_map
: attr_checker->GetDefaultAttrMap(); : attr_checker->GetDefaultAttrMap();
auto dygraph_exe_ctx = auto dygraph_exe_ctx =
imperative::DygraphExecutionContext<imperative::VarBase>( imperative::DygraphExecutionContext<egr::EagerVariable>(
*op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs, *op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs,
default_attrs); default_attrs);
auto* opbase_with_kernel = auto* opbase_with_kernel =
......
...@@ -156,8 +156,8 @@ class Tracer { ...@@ -156,8 +156,8 @@ class Tracer {
} }
phi::KernelSignature GetExpectedKernelSignature( phi::KernelSignature GetExpectedKernelSignature(
const std::string& type, const NameVarBaseMap& ins, const std::string& type, const NameTensorMap& ins,
const NameVarBaseMap& outs, framework::AttributeMap attrs) const; const NameTensorMap& outs, framework::AttributeMap attrs) const;
paddle::framework::GarbageCollector* MutableGarbageCollectorIfNotExists( paddle::framework::GarbageCollector* MutableGarbageCollectorIfNotExists(
const platform::Place& place); const platform::Place& place);
......
...@@ -52,11 +52,13 @@ limitations under the License. */ ...@@ -52,11 +52,13 @@ limitations under the License. */
#include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/memory/allocation/mmap_allocator.h" #include "paddle/fluid/memory/allocation/mmap_allocator.h"
#include "paddle/fluid/operators/utils.h" #include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/op_function.h" #include "paddle/fluid/pybind/op_function.h"
#include "paddle/fluid/pybind/pybind_boost_headers.h" #include "paddle/fluid/pybind/pybind_boost_headers.h"
#include "paddle/fluid/pybind/slice_utils.h" #include "paddle/fluid/pybind/slice_utils.h"
#include "paddle/fluid/pybind/tensor_py.h" #include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/phi/core/compat/arg_map_context.h" #include "paddle/phi/core/compat/arg_map_context.h"
#include "paddle/phi/core/compat/type_defs.h"
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
...@@ -436,6 +438,28 @@ static imperative::NameVarBaseMap ConvertToNameVarBaseMap( ...@@ -436,6 +438,28 @@ static imperative::NameVarBaseMap ConvertToNameVarBaseMap(
return result; return result;
} }
// Converts a Python-side name -> py::object map into the eager-mode
// NameTensorMap used by the imperative tracer. Each Python value is cast to a
// vector of paddle Tensors and re-wrapped as shared_ptr<egr::EagerVariable>;
// entries whose cast yields an empty vector are dropped. If any CPython error
// was raised during conversion, it is surfaced as an InvalidArgument error.
paddle::imperative::NameTensorMap ConvertToNameTensorMap(
    const PyNameVarBaseMap &map) {
  paddle::imperative::NameTensorMap result;
  for (const auto &item : map) {
    auto tensors = CastPyArg2VectorOfTensor(item.second.ptr(), 0);
    if (tensors.empty()) {
      continue;  // skip names with no tensors, same as the VarBase variant
    }
    // Repackage vector<Tensor> as vector<shared_ptr<EagerVariable>>.
    std::vector<std::shared_ptr<egr::EagerVariable>> holders;
    holders.reserve(tensors.size());
    for (auto &tensor : tensors) {
      holders.push_back(std::make_shared<egr::EagerVariable>(std::move(tensor)));
    }
    result.emplace(item.first, std::move(holders));
  }
  // NOTE: the message argument is only evaluated when the check fails, so the
  // py::str construction below is safe when no Python error is pending.
  PADDLE_ENFORCE_EQ(
      PyErr_Occurred(), nullptr,
      platform::errors::InvalidArgument(py::str(py::handle(PyErr_Occurred()))));
  return result;
}
template <typename P> template <typename P>
static void VarBaseCopy(std::shared_ptr<imperative::VarBase> &src, // NOLINT static void VarBaseCopy(std::shared_ptr<imperative::VarBase> &src, // NOLINT
imperative::VarBase &dst, // NOLINT imperative::VarBase &dst, // NOLINT
...@@ -2079,8 +2103,8 @@ void BindImperative(py::module *m_ptr) { ...@@ -2079,8 +2103,8 @@ void BindImperative(py::module *m_ptr) {
const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs, const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
framework::AttributeMap attrs) { framework::AttributeMap attrs) {
// TODO(xiongkun): move this function outside of tracer. // TODO(xiongkun): move this function outside of tracer.
auto ins_map = ConvertToNameVarBaseMap(ins); auto ins_map = ConvertToNameTensorMap(ins);
auto outs_map = ConvertToNameVarBaseMap(outs); auto outs_map = ConvertToNameTensorMap(outs);
{ {
auto to_vector = [](paddle::SmallVector<std::string> &vec) { auto to_vector = [](paddle::SmallVector<std::string> &vec) {
return std::vector<std::string>(vec.begin(), vec.end()); return std::vector<std::string>(vec.begin(), vec.end());
......
...@@ -715,10 +715,11 @@ class OpTest(unittest.TestCase): ...@@ -715,10 +715,11 @@ class OpTest(unittest.TestCase):
assert related_idx >= 0, "%d-th arguments don't have default value" % idx assert related_idx >= 0, "%d-th arguments don't have default value" % idx
return defaults[related_idx] return defaults[related_idx]
def remove_name(x): def filter_by_name(x):
if isinstance(x, list): return [i for i in x if i != 'name'] names = set(['name', 'out', 'output'])
if isinstance(x, list): return [i for i in x if i not in names]
if isinstance(x, dict): if isinstance(x, dict):
return {k: v for k, v in x.items() if k != 'name'} return {k: v for k, v in x.items() if k not in names}
assert False, "Only support list or dict." assert False, "Only support list or dict."
def to_defaults_list(params, defaults): def to_defaults_list(params, defaults):
...@@ -728,7 +729,7 @@ class OpTest(unittest.TestCase): ...@@ -728,7 +729,7 @@ class OpTest(unittest.TestCase):
# Because we don't know the python api name of each arguments. # Because we don't know the python api name of each arguments.
# using parse_arg_and_kwargs, we can get the all api information we need. # using parse_arg_and_kwargs, we can get the all api information we need.
api_params, api_defaults = [ api_params, api_defaults = [
remove_name(item) for item in parse_arg_and_kwargs(api) filter_by_name(item) for item in parse_arg_and_kwargs(api)
] ]
api_defaults = to_defaults_list(api_params, api_defaults) api_defaults = to_defaults_list(api_params, api_defaults)
inputs_sig, attrs_sig, outputs_sig = kernel_sig inputs_sig, attrs_sig, outputs_sig = kernel_sig
...@@ -784,10 +785,10 @@ class OpTest(unittest.TestCase): ...@@ -784,10 +785,10 @@ class OpTest(unittest.TestCase):
block = fluid.default_main_program().global_block() block = fluid.default_main_program().global_block()
op_proto = OpProtoHolder.instance().get_op_proto(self.op_type) op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
# prepare input variable # prepare input variable
inputs = self.append_input_output_for_dygraph(op_proto, self.inputs, eager_tensor_inputs = self.append_input_output_for_dygraph(
True, False, block) op_proto, self.inputs, True, False, block)
# prepare output variable # prepare output variable
outputs = self.append_input_output_for_dygraph( eager_tensor_outputs = self.append_input_output_for_dygraph(
op_proto, self.outputs, False, False, block) op_proto, self.outputs, False, False, block)
# prepare attrbutes # prepare attrbutes
...@@ -798,13 +799,14 @@ class OpTest(unittest.TestCase): ...@@ -798,13 +799,14 @@ class OpTest(unittest.TestCase):
attrs_outputs[attrs_name] = self.attrs[attrs_name] attrs_outputs[attrs_name] = self.attrs[attrs_name]
kernel_sig = _dygraph_tracer()._get_kernel_signature( kernel_sig = _dygraph_tracer()._get_kernel_signature(
self.op_type, inputs, outputs, attrs_outputs) self.op_type, eager_tensor_inputs, eager_tensor_outputs,
attrs_outputs)
assert hasattr( assert hasattr(
self, "python_api" self, "python_api"
), "Please set the `self.python_api` if you want to compare python api output." ), "Please set the `self.python_api` if you want to compare python api output."
args = prepare_python_api_arguments(self.python_api, inputs, args = prepare_python_api_arguments(
attrs_outputs, kernel_sig) self.python_api, eager_tensor_inputs, attrs_outputs, kernel_sig)
""" we directly return the cal_python_api value because the value is already tensor. """ we directly return the cal_python_api value because the value is already tensor.
""" """
return cal_python_api(self.python_api, args, kernel_sig) return cal_python_api(self.python_api, args, kernel_sig)
...@@ -1286,11 +1288,11 @@ class OpTest(unittest.TestCase): ...@@ -1286,11 +1288,11 @@ class OpTest(unittest.TestCase):
with _test_eager_guard(): with _test_eager_guard():
eager_dygraph_outs = self._calc_dygraph_output( eager_dygraph_outs = self._calc_dygraph_output(
place, no_check_set=no_check_set) place, no_check_set=no_check_set)
# we only check end2end api when check_eager=True # we only check end2end api when check_eager=True
if hasattr(self, "python_api"): if hasattr(self, "python_api"):
api_outs = self._calc_python_api_output(place) api_outs = self._calc_python_api_output(place)
self._check_api_outs_by_dygraph_outs(api_outs, dygraph_outs, self._check_api_outs_by_dygraph_outs(api_outs, dygraph_outs,
place) place)
outs, fetch_list = self._calc_output(place, no_check_set=no_check_set) outs, fetch_list = self._calc_output(place, no_check_set=no_check_set)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册