未验证 提交 4d886f75 编写于 作者: X xiongkun 提交者: GitHub

run python api in eager model and filter the out in argument list (#40523)

* run python api in eager model and filter the out in argument list

* fix code
上级 30417999
......@@ -390,8 +390,8 @@ bool Tracer::ComputeRequiredGrad(const NameTensorMap& ins,
}
phi::KernelSignature Tracer::GetExpectedKernelSignature(
const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, framework::AttributeMap attrs) const {
const std::string& type, const NameTensorMap& ins,
const NameTensorMap& outs, framework::AttributeMap attrs) const {
auto op = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
framework::RuntimeContext ctx({}, {});
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
......@@ -406,7 +406,7 @@ phi::KernelSignature Tracer::GetExpectedKernelSignature(
attr_checker == nullptr ? empty_attrs_map
: attr_checker->GetDefaultAttrMap();
auto dygraph_exe_ctx =
imperative::DygraphExecutionContext<imperative::VarBase>(
imperative::DygraphExecutionContext<egr::EagerVariable>(
*op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs,
default_attrs);
auto* opbase_with_kernel =
......
......@@ -156,8 +156,8 @@ class Tracer {
}
phi::KernelSignature GetExpectedKernelSignature(
const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, framework::AttributeMap attrs) const;
const std::string& type, const NameTensorMap& ins,
const NameTensorMap& outs, framework::AttributeMap attrs) const;
paddle::framework::GarbageCollector* MutableGarbageCollectorIfNotExists(
const platform::Place& place);
......
......@@ -52,11 +52,13 @@ limitations under the License. */
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/memory/allocation/mmap_allocator.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/op_function.h"
#include "paddle/fluid/pybind/pybind_boost_headers.h"
#include "paddle/fluid/pybind/slice_utils.h"
#include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/phi/core/compat/arg_map_context.h"
#include "paddle/phi/core/compat/type_defs.h"
namespace paddle {
namespace pybind {
......@@ -436,6 +438,28 @@ static imperative::NameVarBaseMap ConvertToNameVarBaseMap(
return result;
}
paddle::imperative::NameTensorMap ConvertToNameTensorMap(
const PyNameVarBaseMap &map) {
paddle::imperative::NameTensorMap result;
for (auto &pair : map) {
auto var_vec = CastPyArg2VectorOfTensor(pair.second.ptr(), 0);
if (!var_vec.empty()) {
// change vector<Tensor> -> vector<shared_ptr<Tensor>>
std::vector<std::shared_ptr<egr::EagerVariable>> dst_var_vec;
for (auto &v : var_vec) {
dst_var_vec.emplace_back(
std::make_shared<egr::EagerVariable>(std::move(v)));
}
result.emplace(pair.first, std::move(dst_var_vec));
}
}
PADDLE_ENFORCE_EQ(
PyErr_Occurred(), nullptr,
platform::errors::InvalidArgument(py::str(py::handle(PyErr_Occurred()))));
return result;
}
template <typename P>
static void VarBaseCopy(std::shared_ptr<imperative::VarBase> &src, // NOLINT
imperative::VarBase &dst, // NOLINT
......@@ -2079,8 +2103,8 @@ void BindImperative(py::module *m_ptr) {
const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
framework::AttributeMap attrs) {
// TODO(xiongkun): move this function outside of tracer.
auto ins_map = ConvertToNameVarBaseMap(ins);
auto outs_map = ConvertToNameVarBaseMap(outs);
auto ins_map = ConvertToNameTensorMap(ins);
auto outs_map = ConvertToNameTensorMap(outs);
{
auto to_vector = [](paddle::SmallVector<std::string> &vec) {
return std::vector<std::string>(vec.begin(), vec.end());
......
......@@ -715,10 +715,11 @@ class OpTest(unittest.TestCase):
assert related_idx >= 0, "%d-th arguments don't have default value" % idx
return defaults[related_idx]
def remove_name(x):
if isinstance(x, list): return [i for i in x if i != 'name']
def filter_by_name(x):
names = set(['name', 'out', 'output'])
if isinstance(x, list): return [i for i in x if i not in names]
if isinstance(x, dict):
return {k: v for k, v in x.items() if k != 'name'}
return {k: v for k, v in x.items() if k not in names}
assert False, "Only support list or dict."
def to_defaults_list(params, defaults):
......@@ -728,7 +729,7 @@ class OpTest(unittest.TestCase):
# Because we don't know the python api name of each arguments.
# using parse_arg_and_kwargs, we can get the all api information we need.
api_params, api_defaults = [
remove_name(item) for item in parse_arg_and_kwargs(api)
filter_by_name(item) for item in parse_arg_and_kwargs(api)
]
api_defaults = to_defaults_list(api_params, api_defaults)
inputs_sig, attrs_sig, outputs_sig = kernel_sig
......@@ -784,10 +785,10 @@ class OpTest(unittest.TestCase):
block = fluid.default_main_program().global_block()
op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
# prepare input variable
inputs = self.append_input_output_for_dygraph(op_proto, self.inputs,
True, False, block)
eager_tensor_inputs = self.append_input_output_for_dygraph(
op_proto, self.inputs, True, False, block)
# prepare output variable
outputs = self.append_input_output_for_dygraph(
eager_tensor_outputs = self.append_input_output_for_dygraph(
op_proto, self.outputs, False, False, block)
# prepare attrbutes
......@@ -798,13 +799,14 @@ class OpTest(unittest.TestCase):
attrs_outputs[attrs_name] = self.attrs[attrs_name]
kernel_sig = _dygraph_tracer()._get_kernel_signature(
self.op_type, inputs, outputs, attrs_outputs)
self.op_type, eager_tensor_inputs, eager_tensor_outputs,
attrs_outputs)
assert hasattr(
self, "python_api"
), "Please set the `self.python_api` if you want to compare python api output."
args = prepare_python_api_arguments(self.python_api, inputs,
attrs_outputs, kernel_sig)
args = prepare_python_api_arguments(
self.python_api, eager_tensor_inputs, attrs_outputs, kernel_sig)
""" we directly return the cal_python_api value because the value is already tensor.
"""
return cal_python_api(self.python_api, args, kernel_sig)
......@@ -1286,11 +1288,11 @@ class OpTest(unittest.TestCase):
with _test_eager_guard():
eager_dygraph_outs = self._calc_dygraph_output(
place, no_check_set=no_check_set)
# we only check end2end api when check_eager=True
if hasattr(self, "python_api"):
api_outs = self._calc_python_api_output(place)
self._check_api_outs_by_dygraph_outs(api_outs, dygraph_outs,
place)
# we only check end2end api when check_eager=True
if hasattr(self, "python_api"):
api_outs = self._calc_python_api_output(place)
self._check_api_outs_by_dygraph_outs(api_outs, dygraph_outs,
place)
outs, fetch_list = self._calc_output(place, no_check_set=no_check_set)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册