From 79a32715b9aca4a6e522ffcf91bac82e7a6cd380 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Mon, 7 Mar 2022 17:24:16 +0800 Subject: [PATCH] [OpTest] Support to test paddle API end-to-end for check_eager (#40169) * add python api test in TestOp * test_python_api if self.python_api is set * fix code by CR --- paddle/fluid/imperative/tracer.cc | 33 +++++++ paddle/fluid/imperative/tracer.h | 5 + paddle/fluid/pybind/imperative.cc | 21 +++++ .../paddle/fluid/tests/unittests/op_test.py | 94 +++++++++++++++++++ .../fluid/tests/unittests/test_selu_op.py | 1 + 5 files changed, 154 insertions(+) diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index 4336a5c77c..01c9d2847e 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -18,12 +18,14 @@ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/imperative/amp_auto_cast.h" +#include "paddle/fluid/imperative/execution_context.h" #include "paddle/fluid/imperative/op_base.h" #include "paddle/fluid/platform/denormal.h" #include "paddle/fluid/platform/device/device_wrapper.h" #include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/fluid/string/string_helper.h" +#include "paddle/phi/common/place.h" DECLARE_bool(use_mkldnn); DECLARE_string(tracer_mkldnn_ops_on); @@ -382,5 +384,36 @@ bool Tracer::ComputeRequiredGrad(const NameTensorMap& ins, return false; } +phi::KernelSignature Tracer::GetExpectedKernelSignature( + const std::string& type, const NameVarBaseMap& ins, + const NameVarBaseMap& outs, framework::AttributeMap attrs) const { + auto op = framework::OpRegistry::CreateOp(type, {}, {}, {}, false); + framework::RuntimeContext ctx({}, {}); + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto* dev_ctx = pool.Get(phi::CPUPlace()); + const auto& op_info = op->Info(); + auto* attr_checker = op_info.Checker(); + if (attr_checker) { + attr_checker->Check(&attrs, true, /*only_check_exist_value=*/true); + } + static paddle::framework::AttributeMap empty_attrs_map = {}; + const paddle::framework::AttributeMap& default_attrs = + attr_checker == nullptr ? empty_attrs_map + : attr_checker->GetDefaultAttrMap(); + auto dygraph_exe_ctx = + imperative::DygraphExecutionContext( + *op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs, + default_attrs); + auto* opbase_with_kernel = + dynamic_cast(op.get()); + PADDLE_ENFORCE_NE(opbase_with_kernel, nullptr, + platform::errors::InvalidArgument( + "This op type:`%s` is not a OperatorWithKernel, only " + "OperatorWithKernel can get KernelSignature", + type)); + return phi::KernelSignature( + std::move(opbase_with_kernel->GetExpectedPhiKernelArgs(dygraph_exe_ctx))); +} + } // namespace imperative } // namespace paddle diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h index 73ecbbe614..fd13fce6a6 100644 --- a/paddle/fluid/imperative/tracer.h +++ b/paddle/fluid/imperative/tracer.h @@ -28,6 +28,7 @@ #include "paddle/fluid/imperative/jit/program_desc_tracer.h" #include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/platform/macros.h" +#include "paddle/phi/core/compat/arg_map_context.h" namespace paddle { namespace imperative { @@ -154,6 +155,10 @@ class Tracer { } } + phi::KernelSignature GetExpectedKernelSignature( + const std::string& type, const NameVarBaseMap& ins, + const NameVarBaseMap& outs, framework::AttributeMap attrs) const; + paddle::framework::GarbageCollector* MutableGarbageCollectorIfNotExists( const platform::Place& place); diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 3da17b95a6..9b373a5818 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -56,6 +56,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/pybind_boost_headers.h" #include "paddle/fluid/pybind/slice_utils.h" #include "paddle/fluid/pybind/tensor_py.h" +#include "paddle/phi/core/compat/arg_map_context.h" namespace paddle { namespace pybind { @@ -2073,6 +2074,26 @@ void BindImperative(py::module *m_ptr) { *(imperative::AmpOperators::Instance().GetMutableAllowOps()), *(imperative::AmpOperators::Instance().GetMutableBlockOps())); }) + .def("_get_kernel_signature", + [](imperative::Tracer &self, const std::string &type, + const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs, + framework::AttributeMap attrs) { + // TODO(xiongkun): move this function outside of tracer. + auto ins_map = ConvertToNameVarBaseMap(ins); + auto outs_map = ConvertToNameVarBaseMap(outs); + { + auto to_vector = [](paddle::SmallVector &vec) { + return std::vector(vec.begin(), vec.end()); + }; + auto ret = self.GetExpectedKernelSignature(type, ins_map, + outs_map, attrs); + auto kernelsig_ins = to_vector(std::get<0>(ret.args)); + auto kernelsig_attrs = to_vector(std::get<1>(ret.args)); + auto kernelsig_outs = to_vector(std::get<2>(ret.args)); + return std::make_tuple(kernelsig_ins, kernelsig_attrs, + kernelsig_outs); + } + }) .def("trace", [](imperative::Tracer &self, const std::string &type, const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs, diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 628791afef..0c7f269a08 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -29,6 +29,7 @@ from copy import copy import paddle import paddle.fluid as fluid +from paddle.fluid.framework import _dygraph_tracer import paddle.fluid.core as core from paddle.fluid.framework import _in_eager_mode from paddle.fluid.framework import _test_eager_guard @@ -395,6 +396,7 @@ class OpTest(unittest.TestCase): hasattr(self, "attrs") and "use_xpu" in self.attrs and self.attrs["use_xpu"] == True) + # set the self.output_dtype . def infer_dtype_from_inputs_outputs(self, inputs, outputs): def is_np_data(input): return isinstance(input, (np.ndarray, np.generic)) @@ -679,6 +681,91 @@ class OpTest(unittest.TestCase): else: return var_dict + def _check_api_outs_by_dygraph_outs(self, api_outs, dygraph_outs, place): + """ for quick verify, here we take a simplest strategy: + 1. we only check variable in api_outs. + 2. we simply check the numpy (tensor) . + 3. we set atol and rtol as 1e-5, because they are unrelated to dtype. + """ + for name in api_outs: + np_api = np.array(api_outs[name]) + np_dyg = np.array(dygraph_outs[name]) + self.assertTrue( + np.allclose( + np_api, np_dyg, equal_nan=False), + "Output (" + name + ") has diff at " + str(place) + "\nExpect " + + str(np_dyg) + "\n" + "But Got" + str(np_api) + " in class " + + self.__class__.__name__) + + def _calc_python_api_output(self, place): + def prepare_python_api_arguments(op_proto_ins, op_proto_attrs, + kernel_sig): + """ map from `op proto inputs and attrs` to `api input list and api attrs dict` + """ + # NOTE(xiongkun): why don't use input arguments dicts ? + # Because we don't know the python api name of each arguments. + inputs_sig, attrs_sig, outputs_sig = kernel_sig + input_arguments = [op_proto_ins[name] for name in inputs_sig] + attr_arguments = { + name: op_proto_attrs[name] + for name in attrs_sig if name in op_proto_attrs + } + return input_arguments, attr_arguments + + def construct_output_dict_by_kernel_sig(ret_tuple, output_sig): + if not isinstance(ret_tuple, (tuple, list)): + ret_tuple = [ret_tuple] + assert len(output_sig) == len( + ret_tuple), "expect %d outputs, but get %d outputs" % ( + len(output_sig), len(ret_tuple)) + return {a: b for a, b in zip(output_sig, ret_tuple)} + + def assumption_assert_and_transform(args, argvs): + """ + currently only support "X" is [Tensor], don't support multi-tensor in "X" + """ + for inp in args: + assert isinstance(inp, list) and len( + inp + ) == 1, "currently only support `X` is [Tensor], don't support multi-tensor in `X`" + args = [inp[0] for inp in args] + return args, argvs + + def cal_python_api(python_api, args, argvs, kernel_sig): + args, argvs = assumption_assert_and_transform(args, argvs) + inputs_sig, attrs_sig, outputs_sig = kernel_sig + ret_tuple = python_api(*args, **argvs) + return construct_output_dict_by_kernel_sig(ret_tuple, outputs_sig) + + with fluid.dygraph.base.guard(place=place): + block = fluid.default_main_program().global_block() + op_proto = OpProtoHolder.instance().get_op_proto(self.op_type) + # prepare input variable + inputs = self.append_input_output_for_dygraph(op_proto, self.inputs, + True, False, block) + # prepare output variable + outputs = self.append_input_output_for_dygraph( + op_proto, self.outputs, False, False, block) + + # prepare attrbutes + attrs_outputs = {} + if hasattr(self, "attrs"): + for attrs_name in self.attrs: + if self.attrs[attrs_name] is not None: + attrs_outputs[attrs_name] = self.attrs[attrs_name] + + kernel_sig = _dygraph_tracer()._get_kernel_signature( + self.op_type, inputs, outputs, attrs_outputs) + + assert hasattr( + self, "python_api" + ), "Please set the `self.python_api` if you want to compare python api output." + arg, argv = prepare_python_api_arguments(inputs, attrs_outputs, + kernel_sig) + """ we directly return the cal_python_api value because the value is already tensor. + """ + return cal_python_api(self.python_api, arg, argv, kernel_sig) + def _calc_dygraph_output(self, place, parallel=False, no_check_set=None): self.__class__.op_type = self.op_type # for ci check, please not delete it for now with fluid.dygraph.base.guard(place=place): @@ -699,6 +786,7 @@ class OpTest(unittest.TestCase): for attrs_name in self.attrs: if self.attrs[attrs_name] is not None: attrs_outputs[attrs_name] = self.attrs[attrs_name] + block.append_op( type=self.op_type, inputs=inputs, @@ -1150,6 +1238,12 @@ class OpTest(unittest.TestCase): if check_dygraph: dygraph_outs = self._calc_dygraph_output( place, no_check_set=no_check_set) + + if hasattr(self, "python_api"): + api_outs = self._calc_python_api_output(place) + self._check_api_outs_by_dygraph_outs(api_outs, dygraph_outs, + place) + if check_eager: with _test_eager_guard(): eager_dygraph_outs = self._calc_dygraph_output( diff --git a/python/paddle/fluid/tests/unittests/test_selu_op.py b/python/paddle/fluid/tests/unittests/test_selu_op.py index e71adae8d9..f161988179 100644 --- a/python/paddle/fluid/tests/unittests/test_selu_op.py +++ b/python/paddle/fluid/tests/unittests/test_selu_op.py @@ -42,6 +42,7 @@ def ref_selu(x, class SeluTest(OpTest): def setUp(self): self.op_type = "selu" + self.python_api = paddle.nn.functional.selu self.x_shape = [3, 5, 5, 10] self.dtype = np.float64 self.init_x_shape() -- GitLab