提交 a9d5fc51 编写于 作者: L Leo Chen 提交者: Zeng Jinle

Enhance OpTest to check the consistency of operators when using and not using inplace (#19101)

* add pybind interface to get all inplace ops, test=develop

* enhance OpTest to check whether the consistency of operator when using and not using inplace, test=develop

* handle corner cases in op_test, test=develop

* support outputs without tensor holder_, like XShape in reshape_op, test=develop

* fix bug, some op has GradOpMaker, but actually no grad_op in OpInfoMap, test=develop

* use reshape_grad instead of reshape in FlattenGradOp, test=develop

* fix error debug dims info for variables like XShape, test=develop

* change computational order in sum_op to relieve computation difference using inplace, test=develop

* add inplace_atol to check group_norm, and skip inplace_grad for mkldnn, test=develop

* follow sneaxiy's comments, test=develop

* remove unused DefaultGradOpDescMaker in mkldnn op, test=develop
上级 0d29cf18
...@@ -78,6 +78,15 @@ struct OpInfo { ...@@ -78,6 +78,15 @@ struct OpInfo {
return grad_op_maker_; return grad_op_maker_;
} }
// some op has no grad_op_maker, add check before use GradOpMaker()
bool HasGradOpMaker() const {
return grad_op_maker_ != nullptr ? true : false;
}
bool HasInferInplace() const {
return infer_inplace_ != nullptr ? true : false;
}
const OpAttrChecker* Checker() const { return checker_; } const OpAttrChecker* Checker() const { return checker_; }
const InferNoNeedBufferVarsFN& NoNeedBufferVarsInferer() const { const InferNoNeedBufferVarsFN& NoNeedBufferVarsInferer() const {
......
...@@ -67,9 +67,6 @@ static DDim GetDimsDebug(const Scope& scope, const std::string& name, ...@@ -67,9 +67,6 @@ static DDim GetDimsDebug(const Scope& scope, const std::string& name,
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
const LoDTensor& tensor = var->Get<LoDTensor>(); const LoDTensor& tensor = var->Get<LoDTensor>();
if (UNLIKELY(!tensor.IsInitialized())) {
return DDim({-1});
}
return tensor.dims(); return tensor.dims();
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
if (get_actual_dim) { if (get_actual_dim) {
......
conv_shift conv_shift
cos_sim cos_sim
dequantize
fc fc
flatten flatten
fsp fsp
...@@ -17,13 +16,11 @@ nce ...@@ -17,13 +16,11 @@ nce
pool2d pool2d
pool3d pool3d
prelu prelu
quantize
rank_loss rank_loss
reduce_max reduce_max
reduce_min reduce_min
reduce_prod reduce_prod
reduce_sum reduce_sum
requantize
reshape reshape
rnn_memory_helper rnn_memory_helper
sequence_softmax sequence_softmax
......
...@@ -41,5 +41,4 @@ void DeQuantOpMaker::Make() { ...@@ -41,5 +41,4 @@ void DeQuantOpMaker::Make() {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(dequantize, ops::DeQuantOp, ops::DeQuantOpMaker, REGISTER_OPERATOR(dequantize, ops::DeQuantOp, ops::DeQuantOpMaker);
paddle::framework::DefaultGradOpDescMaker<true>);
...@@ -260,10 +260,11 @@ class Flatten2GradOp : public framework::OperatorBase { ...@@ -260,10 +260,11 @@ class Flatten2GradOp : public framework::OperatorBase {
attrs["shape"] = framework::vectorize2int(x_dims); attrs["shape"] = framework::vectorize2int(x_dims);
attrs["inplace"] = false; attrs["inplace"] = false;
auto reshape_op = framework::OpRegistry::CreateOp( auto reshape_grad_op = framework::OpRegistry::CreateOp(
"reshape2", {{"X", {dout_name}}, {"Shape", {}}}, "reshape2_grad",
{{"Out", {dx_name}}, {"XShape", {xshape_name}}}, attrs); {{"Out@GRAD", {dout_name}}, {"Shape", {}}, {"XShape", {xshape_name}}},
reshape_op->Run(scope, place); {{"X@GRAD", {dx_name}}}, attrs);
reshape_grad_op->Run(scope, place);
} }
}; };
......
...@@ -43,5 +43,4 @@ void QuantOpMaker::Make() { ...@@ -43,5 +43,4 @@ void QuantOpMaker::Make() {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(quantize, ops::QuantOp, ops::QuantOpMaker, REGISTER_OPERATOR(quantize, ops::QuantOp, ops::QuantOpMaker);
paddle::framework::DefaultGradOpDescMaker<true>);
...@@ -42,5 +42,4 @@ void ReQuantOpMaker::Make() { ...@@ -42,5 +42,4 @@ void ReQuantOpMaker::Make() {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(requantize, ops::ReQuantOp, ops::ReQuantOpMaker, REGISTER_OPERATOR(requantize, ops::ReQuantOp, ops::ReQuantOpMaker);
paddle::framework::DefaultGradOpDescMaker<true>);
...@@ -38,18 +38,14 @@ __global__ void SumArrayCUDAKernel(T **in, T *out, int64_t N, size_t in_size, ...@@ -38,18 +38,14 @@ __global__ void SumArrayCUDAKernel(T **in, T *out, int64_t N, size_t in_size,
bool read_dst) { bool read_dst) {
int id = blockIdx.x * blockDim.x + threadIdx.x; int id = blockIdx.x * blockDim.x + threadIdx.x;
while (id < N) { while (id < N) {
T total(0); T total(read_dst ? out[id] : static_cast<T>(0));
for (int i = 0; i < in_size; ++i) { for (int i = 0; i < in_size; ++i) {
const T *tmp = in[i]; const T *tmp = in[i];
if (tmp) { if (tmp) {
total += tmp[id]; total += tmp[id];
} }
} }
if (read_dst) {
out[id] += total;
} else {
out[id] = total; out[id] = total;
}
id += blockDim.x * gridDim.x; id += blockDim.x * gridDim.x;
} }
} }
......
...@@ -85,6 +85,7 @@ limitations under the License. */ ...@@ -85,6 +85,7 @@ limitations under the License. */
DEFINE_bool(reader_queue_speed_test_mode, false, DEFINE_bool(reader_queue_speed_test_mode, false,
"If set true, the queue.pop will only get data from queue but not " "If set true, the queue.pop will only get data from queue but not "
"remove the data from queue for speed testing"); "remove the data from queue for speed testing");
DECLARE_bool(use_mkldnn);
// disable auto conversion to list in Python // disable auto conversion to list in Python
PYBIND11_MAKE_OPAQUE(paddle::framework::LoDTensorArray); PYBIND11_MAKE_OPAQUE(paddle::framework::LoDTensorArray);
...@@ -489,10 +490,24 @@ PYBIND11_MODULE(core_noavx, m) { ...@@ -489,10 +490,24 @@ PYBIND11_MODULE(core_noavx, m) {
Returns: Returns:
out (Tensor): new Tensor(NOT LoDTensor). out (Tensor): new Tensor(NOT LoDTensor).
)DOC") )DOC")
.def("__str__", [](const LoDTensor &self) { .def("__str__",
[](const LoDTensor &self) {
std::stringstream ostr; std::stringstream ostr;
ostr << self; ostr << self;
return ostr.str(); return ostr.str();
})
.def("_copy", [](const LoDTensor &self, const platform::Place &place) {
// follow fetch_op's inplementation
LoDTensor dst;
if (self.IsInitialized() && self.numel() > 0) {
TensorCopySync(self, place, &dst);
} else {
// Not copy, if the src tensor is empty.
dst.clear();
dst.Resize({0});
}
dst.set_lod(self.lod());
return dst;
}); });
py::class_<SelectedRows>(m, "SelectedRows") py::class_<SelectedRows>(m, "SelectedRows")
...@@ -718,6 +733,14 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -718,6 +733,14 @@ All parameter, weight, gradient are variables in Paddle.
[](std::unique_ptr<OpDesc> &p) { return p.release(); }); [](std::unique_ptr<OpDesc> &p) { return p.release(); });
return std::make_pair(grad_op_desc_ptrs, grad_to_var); return std::make_pair(grad_op_desc_ptrs, grad_to_var);
}); });
m.def("has_grad_op_maker", [](const std::string op_type) {
return framework::OpInfoMap::Instance().Get(op_type).HasGradOpMaker();
});
m.def("has_infer_inplace", [](const std::string op_type) {
return framework::OpInfoMap::Instance().Get(op_type).HasInferInplace();
});
m.def("get_flags_use_mkldnn", []() { return FLAGS_use_mkldnn; });
m.def("prune", [](const ProgramDesc &origin, m.def("prune", [](const ProgramDesc &origin,
const std::vector<std::array<size_t, 2>> &targets) { const std::vector<std::array<size_t, 2>> &targets) {
ProgramDesc prog_with_targets(origin); ProgramDesc prog_with_targets(origin);
......
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
import os import os
import unittest import unittest
import warnings
import numpy as np import numpy as np
import random import random
import six import six
...@@ -236,6 +237,8 @@ class OpTest(unittest.TestCase): ...@@ -236,6 +237,8 @@ class OpTest(unittest.TestCase):
op.desc.infer_var_type(block.desc) op.desc.infer_var_type(block.desc)
op.desc.infer_shape(block.desc) op.desc.infer_shape(block.desc)
return op
def _get_io_vars(self, block, numpy_inputs): def _get_io_vars(self, block, numpy_inputs):
inputs = {} inputs = {}
for name, value in six.iteritems(numpy_inputs): for name, value in six.iteritems(numpy_inputs):
...@@ -316,7 +319,13 @@ class OpTest(unittest.TestCase): ...@@ -316,7 +319,13 @@ class OpTest(unittest.TestCase):
return outputs return outputs
def _calc_output(self, place, parallel=False, no_check_set=None, loss=None): def _calc_output(self,
place,
parallel=False,
no_check_set=None,
loss=None,
enable_inplace=None,
for_inplace_grad_test=None):
program = Program() program = Program()
block = program.global_block() block = program.global_block()
self._append_ops(block) self._append_ops(block)
...@@ -325,6 +334,14 @@ class OpTest(unittest.TestCase): ...@@ -325,6 +334,14 @@ class OpTest(unittest.TestCase):
outputs = self._get_outputs(block) outputs = self._get_outputs(block)
feed_map = self.feed_var(inputs, place) feed_map = self.feed_var(inputs, place)
if for_inplace_grad_test is not None:
# Some variables' tensors hold no buffer (tensor's _holder is NULL), like XShape in reshape2 op,
# and the shapes of those variables contain 0 (eg. Xshape.shape = [0, 2, 5]).
# Set persistable for those variables in order to get them from global_scope for inplace grad test directly other than feed them,
# since feed op calls check_memory_size() which fails when tensor's holder_ is NULL.
for name, var in block.vars.items():
if 0 in var.shape:
var.persistable = True
if parallel: if parallel:
use_cuda = False use_cuda = False
if isinstance(place, fluid.CUDAPlace): if isinstance(place, fluid.CUDAPlace):
...@@ -351,6 +368,15 @@ class OpTest(unittest.TestCase): ...@@ -351,6 +368,15 @@ class OpTest(unittest.TestCase):
# fetch_list = map(block.var, fetch_list) # fetch_list = map(block.var, fetch_list)
if not isinstance(fetch_list[0], fluid.framework.Variable): if not isinstance(fetch_list[0], fluid.framework.Variable):
fetch_list = list(map(block.var, fetch_list)) fetch_list = list(map(block.var, fetch_list))
if enable_inplace is not None:
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = enable_inplace
compiled_prog = fluid.CompiledProgram(program).with_data_parallel(
build_strategy=build_strategy, places=place)
program = compiled_prog
executor = Executor(place) executor = Executor(place)
outs = executor.run(program, outs = executor.run(program,
feed=feed_map, feed=feed_map,
...@@ -358,12 +384,168 @@ class OpTest(unittest.TestCase): ...@@ -358,12 +384,168 @@ class OpTest(unittest.TestCase):
return_numpy=False) return_numpy=False)
return outs, fetch_list return outs, fetch_list
def check_inplace_output_with_place(self,
place,
no_check_set=None,
inplace_atol=None):
# can`t enable inplace
if not fluid.core.has_infer_inplace(self.op_type):
return
expect_outs, fetch_list = self._calc_output(
place, no_check_set=no_check_set, enable_inplace=False)
actual_outs, fetch_list = self._calc_output(
place, no_check_set=no_check_set, enable_inplace=True)
# compare expect_outs and actual_outs
for i, out in enumerate(fetch_list):
if inplace_atol is not None:
self.assertTrue(
np.allclose(
np.array(expect_outs[i]),
np.array(actual_outs[i]),
atol=inplace_atol),
"Output (" + out.name + ") has diff at " + str(place) +
" when using and not using inplace" + "\nExpect " +
str(expect_outs[i]) + "\n" + "But Got" + str(actual_outs[i])
+ " in class " + self.__class__.__name__)
else:
self.assertTrue(
np.array_equal(
np.array(expect_outs[i]), np.array(actual_outs[i])),
"Output (" + out.name + ") has diff at " + str(place) +
" when using and not using inplace" + "\nExpect " +
str(expect_outs[i]) + "\n" + "But Got" + str(actual_outs[i])
+ " in class " + self.__class__.__name__ + '\n')
def check_inplace_grad_output_with_place(self,
place,
no_check_set=None,
inplace_atol=None):
# create forward program to get forward vars
program = Program()
block = program.global_block()
op = self._append_ops(block)
inputs = self._get_inputs(block)
outputs = self._get_outputs(block)
feed_map = self.feed_var(inputs, place)
# get grad_op
if not fluid.core.has_grad_op_maker(op.desc.type()):
return
grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc(op.desc,
set(), [])
# has grad_op_maker but no grad_op
if not grad_op_desc_list:
return
for i, grad_op_desc in enumerate(grad_op_desc_list):
# grad_op can not inplace
if not fluid.core.has_infer_inplace(grad_op_desc.type()):
continue
# get forward outs
forward_outs, fetch_list = self._calc_output(
place, no_check_set=no_check_set, for_inplace_grad_test=True)
# create grad program
grad_program = Program()
grad_block = grad_program.global_block()
new_op_desc = grad_block.desc.append_op()
new_op_desc.copy_from(grad_op_desc)
grad_program._sync_with_cpp()
# create grad vars based on forward vars (shape and dtype)
for arg in grad_op_desc.input_arg_names(
) + grad_op_desc.output_arg_names():
forward_var_name = op_grad_to_var.get(arg, None)
if forward_var_name is None:
forward_var_name = arg
forward_var = block.vars.get(forward_var_name)
assert forward_var is not None, "{} cannot be found".format(
forward_var_name)
grad_var = grad_block.create_var(
name=arg,
dtype=forward_var.dtype,
shape=forward_var.shape,
type=forward_var.type,
persistable=False)
# some variables' tensors hold no buffer (tensor's _holder is NULL), like XShape in reshape2 op,
# and the shapes of those variables contain 0 (eg. Xshape.shape = [0, 2, 5]).
# set persistable for those variables in order to get them from global_scope for inplace grad test directly other than feed them,
# since feed op calls check_memory_size() which fails when tensor's holder_ is NULL.
if 0 in grad_var.shape:
grad_var.persistable = True
grad_program._sync_with_cpp()
grad_fetch_list = grad_op_desc.output_arg_names()
def _calc_grad_output(enable_inplace=None):
# generate feed_map for grad_program
# since we don`t really check gradient accuracy, but the consistency when using and not using inplace
# we use forward outs (also inputs sometimes) as grad (fake) feeds
p = core.Place()
p.set_place(place)
grad_feed_map = {}
for arg in grad_op_desc.input_arg_names():
if arg in feed_map.keys():
grad_feed_map[arg] = feed_map[arg]._copy(p)
else:
forward_var_name = op_grad_to_var.get(arg, None)
if forward_var_name is None:
forward_var_name = arg
for i, out in enumerate(fetch_list):
if out.name == forward_var_name:
# don't feed variables whose tensors hold no buffer (shape contains 0 like shape = [0,2,5] and holder_ is NULL), like XShape in reshape2 op.
# get them from global_scope directly since we have set them persistable in forward execution
if 0 in out.shape:
continue
else:
grad_feed_map[arg] = forward_outs[i]._copy(
p)
exe = Executor(place)
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = enable_inplace
compiled_program = fluid.CompiledProgram(
grad_program).with_data_parallel(
build_strategy=build_strategy, places=place)
outs = exe.run(compiled_program,
feed=grad_feed_map,
fetch_list=grad_fetch_list,
return_numpy=False)
return outs
expect_outs = _calc_grad_output(enable_inplace=False)
actual_outs = _calc_grad_output(enable_inplace=True)
# compare expect_outs and actual_outs
for i, out_name in enumerate(grad_fetch_list):
if inplace_atol is not None:
self.assertTrue(
np.allclose(
np.array(expect_outs[i]),
np.array(actual_outs[i]),
atol=inplace_atol),
"Output (" + out_name + ") has diff at " + str(place) +
" when using and not using inplace" + "\nExpect " +
str(expect_outs[i]) + "\n" + "But Got" +
str(actual_outs[i]) + " in class " +
self.__class__.__name__)
else:
self.assertTrue(
np.array_equal(
np.array(expect_outs[i]), np.array(actual_outs[i])),
"Output (" + out_name + ") has diff at " + str(place) +
" when using and not using inplace" + "\nExpect " +
str(expect_outs[i]) + "\n" + "But Got" +
str(actual_outs[i]) + " in class " +
self.__class__.__name__)
def check_output_with_place(self, def check_output_with_place(self,
place, place,
atol, atol,
no_check_set=None, no_check_set=None,
equal_nan=False, equal_nan=False,
check_dygraph=False): check_dygraph=False,
inplace_atol=None):
if check_dygraph: if check_dygraph:
dygraph_outs = self._calc_dygraph_output( dygraph_outs = self._calc_dygraph_output(
place, no_check_set=no_check_set) place, no_check_set=no_check_set)
...@@ -464,6 +646,25 @@ class OpTest(unittest.TestCase): ...@@ -464,6 +646,25 @@ class OpTest(unittest.TestCase):
"Output (" + out_name + ") has different lod at " + "Output (" + out_name + ") has different lod at " +
str(place) + " in dygraph mode") str(place) + " in dygraph mode")
# inplace_atol only used when op doesn't ensure computational consistency
if inplace_atol is not None:
warnings.warn(
"By default, inplace_atol should not be set, please check it")
self.check_inplace_output_with_place(
place, no_check_set=no_check_set, inplace_atol=inplace_atol)
# TODO(zhiqiu): enhance inplace_grad test for ops (sum and activation) using mkldnn
# skip use_mkldnn currently
flags_use_mkldnn = fluid.core.get_flags_use_mkldnn()
attrs_use_mkldnn = hasattr(
self, 'attrs') and bool(self.attrs.get('use_mkldnn', False))
if flags_use_mkldnn or attrs_use_mkldnn:
warnings.warn(
"check inplace_grad for ops using mkldnn is not supported")
return
self.check_inplace_grad_output_with_place(
place, no_check_set=no_check_set, inplace_atol=inplace_atol)
def _get_places(self): def _get_places(self):
if self.dtype == np.float16: if self.dtype == np.float16:
if core.is_compiled_with_cuda() and core.op_support_gpu( if core.is_compiled_with_cuda() and core.op_support_gpu(
...@@ -489,7 +690,8 @@ class OpTest(unittest.TestCase): ...@@ -489,7 +690,8 @@ class OpTest(unittest.TestCase):
atol=1e-5, atol=1e-5,
no_check_set=None, no_check_set=None,
equal_nan=False, equal_nan=False,
check_dygraph=False): check_dygraph=False,
inplace_atol=None):
places = self._get_places() places = self._get_places()
for place in places: for place in places:
self.check_output_with_place(place, atol, no_check_set, equal_nan, self.check_output_with_place(place, atol, no_check_set, equal_nan,
......
...@@ -61,11 +61,15 @@ class TestGroupNormOp(OpTest): ...@@ -61,11 +61,15 @@ class TestGroupNormOp(OpTest):
def test_check_output(self): def test_check_output(self):
atol = 1e-4 atol = 1e-4
inplace_atol = 1e-4
place = core.CPUPlace() place = core.CPUPlace()
self.check_output_with_place(place, atol=atol) # add inplace_atol bacause group_norm doesn't ensure computational consistency
self.check_output_with_place(
place, atol=atol, inplace_atol=inplace_atol)
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=atol) self.check_output_with_place(
place, atol=atol, inplace_atol=inplace_atol)
def do_compare_between_place(self): def do_compare_between_place(self):
if not core.is_compiled_with_cuda(): return if not core.is_compiled_with_cuda(): return
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册