diff --git a/paddle/fluid/framework/op_info.h b/paddle/fluid/framework/op_info.h
index ab1554d140bc83067678abd2f148d2abb57aa433..765ca361f61f78de73003e22e38796c39e12d2e5 100644
--- a/paddle/fluid/framework/op_info.h
+++ b/paddle/fluid/framework/op_info.h
@@ -78,6 +78,15 @@ struct OpInfo {
     return grad_op_maker_;
   }
 
+  // Some ops have no grad_op_maker; check HasGradOpMaker() before using GradOpMaker().
+  bool HasGradOpMaker() const {
+    return grad_op_maker_ != nullptr;
+  }
+
+  bool HasInferInplace() const {
+    return infer_inplace_ != nullptr;
+  }
+
   const OpAttrChecker* Checker() const { return checker_; }
 
   const InferNoNeedBufferVarsFN& NoNeedBufferVarsInferer() const {
diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index 02a9eb20264a5de028f29c8a86f459ae5461ba9e..9bd70d46046fa5627533f546639f6b896e91f90a 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -67,9 +67,6 @@ static DDim GetDimsDebug(const Scope& scope, const std::string& name,
 
   if (var->IsType<LoDTensor>()) {
     const LoDTensor& tensor = var->Get<LoDTensor>();
-    if (UNLIKELY(!tensor.IsInitialized())) {
-      return DDim({-1});
-    }
     return tensor.dims();
   } else if (var->IsType<SelectedRows>()) {
     if (get_actual_dim) {
diff --git a/paddle/fluid/op_use_default_grad_op_maker.spec b/paddle/fluid/op_use_default_grad_op_maker.spec
index 97efd21adf1a712802b65bef1d377456b712ffc9..0d106d8a6924281d347a0449cb5212fbcd0be5f1 100644
--- a/paddle/fluid/op_use_default_grad_op_maker.spec
+++ b/paddle/fluid/op_use_default_grad_op_maker.spec
@@ -1,6 +1,5 @@
 conv_shift
 cos_sim
-dequantize
 fc
 flatten
 fsp
@@ -17,13 +16,11 @@ nce
 pool2d
 pool3d
 prelu
-quantize
 rank_loss
 reduce_max
 reduce_min
 reduce_prod
 reduce_sum
-requantize
 reshape
 rnn_memory_helper
 sequence_softmax
diff --git a/paddle/fluid/operators/dequantize_op.cc b/paddle/fluid/operators/dequantize_op.cc
index 38159f84a0d56f45cfef233a3c70c3c6cef17d9f..97f49dbcb08e4428b4857f4a70ab21399fb35612 100644
--- a/paddle/fluid/operators/dequantize_op.cc
+++ b/paddle/fluid/operators/dequantize_op.cc
@@ -41,5 +41,4 @@ void DeQuantOpMaker::Make() {
 
 namespace ops = paddle::operators;
 
-REGISTER_OPERATOR(dequantize, ops::DeQuantOp, ops::DeQuantOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
+REGISTER_OPERATOR(dequantize, ops::DeQuantOp, ops::DeQuantOpMaker);
diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc
index f4085daa10697c39cce63b0db4e0e32fde2374d5..3111cace66c8c5602737ce4f80338725e4a8a6dd 100644
--- a/paddle/fluid/operators/flatten_op.cc
+++ b/paddle/fluid/operators/flatten_op.cc
@@ -260,10 +260,11 @@ class Flatten2GradOp : public framework::OperatorBase {
 
     attrs["shape"] = framework::vectorize2int(x_dims);
     attrs["inplace"] = false;
-    auto reshape_op = framework::OpRegistry::CreateOp(
-        "reshape2", {{"X", {dout_name}}, {"Shape", {}}},
-        {{"Out", {dx_name}}, {"XShape", {xshape_name}}}, attrs);
-    reshape_op->Run(scope, place);
+    auto reshape_grad_op = framework::OpRegistry::CreateOp(
+        "reshape2_grad",
+        {{"Out@GRAD", {dout_name}}, {"Shape", {}}, {"XShape", {xshape_name}}},
+        {{"X@GRAD", {dx_name}}}, attrs);
+    reshape_grad_op->Run(scope, place);
   }
 };
 
diff --git a/paddle/fluid/operators/quantize_op.cc b/paddle/fluid/operators/quantize_op.cc
index bf70c08bdb82218a2d0f63f3e70a2a1093e6a542..d8e20f4c4ae6059551bfff3603a2ad6c0a7aa86d 100644
--- a/paddle/fluid/operators/quantize_op.cc
+++ b/paddle/fluid/operators/quantize_op.cc
@@ -43,5 +43,4 @@ void QuantOpMaker::Make() {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(quantize, ops::QuantOp, ops::QuantOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
+REGISTER_OPERATOR(quantize, ops::QuantOp, ops::QuantOpMaker);
diff --git a/paddle/fluid/operators/requantize_op.cc b/paddle/fluid/operators/requantize_op.cc
index 08ba1470aaddf146fe3685ff6c3cd9f3d7e16d75..d156ae207763433ea2ed7fb97a08cbe5880da3cd 100644
--- a/paddle/fluid/operators/requantize_op.cc
+++ b/paddle/fluid/operators/requantize_op.cc
@@ -42,5 +42,4 @@ void ReQuantOpMaker::Make() {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(requantize, ops::ReQuantOp, ops::ReQuantOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
+REGISTER_OPERATOR(requantize, ops::ReQuantOp, ops::ReQuantOpMaker);
diff --git a/paddle/fluid/operators/sum_op.cu b/paddle/fluid/operators/sum_op.cu
index ba874549ce35fcdfb7026e3368b8736460069ae2..e3f31c0ae8ecd07b2f06ea2bfa13b32e4a8bdb37 100644
--- a/paddle/fluid/operators/sum_op.cu
+++ b/paddle/fluid/operators/sum_op.cu
@@ -38,18 +38,14 @@ __global__ void SumArrayCUDAKernel(T **in, T *out, int64_t N, size_t in_size,
                                    bool read_dst) {
   int id = blockIdx.x * blockDim.x + threadIdx.x;
   while (id < N) {
-    T total(0);
+    T total(read_dst ? out[id] : static_cast<T>(0));
     for (int i = 0; i < in_size; ++i) {
       const T *tmp = in[i];
      if (tmp) {
         total += tmp[id];
       }
     }
-    if (read_dst) {
-      out[id] += total;
-    } else {
-      out[id] = total;
-    }
+    out[id] = total;
     id += blockDim.x * gridDim.x;
   }
 }
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 2303c2e6656331cd5b8189c086bcaa4cd97f23ab..5944c93a055e81b8056e43159b78d25d1a21e51b 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -85,6 +85,7 @@ limitations under the License. */
 DEFINE_bool(reader_queue_speed_test_mode, false,
             "If set true, the queue.pop will only get data from queue but not "
             "remove the data from queue for speed testing");
+DECLARE_bool(use_mkldnn);
 
 // disable auto conversion to list in Python
 PYBIND11_MAKE_OPAQUE(paddle::framework::LoDTensorArray);
@@ -489,10 +490,24 @@ PYBIND11_MODULE(core_noavx, m) {
         Returns:
             out (Tensor): new Tensor(NOT LoDTensor).
         )DOC")
-      .def("__str__", [](const LoDTensor &self) {
-        std::stringstream ostr;
-        ostr << self;
-        return ostr.str();
+      .def("__str__",
+           [](const LoDTensor &self) {
+             std::stringstream ostr;
+             ostr << self;
+             return ostr.str();
+           })
+      .def("_copy", [](const LoDTensor &self, const platform::Place &place) {
+        // Follow fetch_op's implementation.
+        LoDTensor dst;
+        if (self.IsInitialized() && self.numel() > 0) {
+          TensorCopySync(self, place, &dst);
+        } else {
+          // Do not copy if the source tensor is empty.
+          dst.clear();
+          dst.Resize({0});
+        }
+        dst.set_lod(self.lod());
+        return dst;
       });
 
   py::class_<SelectedRows>(m, "SelectedRows")
@@ -718,6 +733,14 @@ All parameter, weight, gradient are variables in Paddle.
         [](std::unique_ptr<OpDesc> &p) { return p.release(); });
     return std::make_pair(grad_op_desc_ptrs, grad_to_var);
   });
+  m.def("has_grad_op_maker", [](const std::string op_type) {
+    return framework::OpInfoMap::Instance().Get(op_type).HasGradOpMaker();
+  });
+  m.def("has_infer_inplace", [](const std::string op_type) {
+    return framework::OpInfoMap::Instance().Get(op_type).HasInferInplace();
+  });
+  m.def("get_flags_use_mkldnn", []() { return FLAGS_use_mkldnn; });
+
   m.def("prune", [](const ProgramDesc &origin,
                     const std::vector<std::array<size_t, 2>> &targets) {
     ProgramDesc prog_with_targets(origin);
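The pybind changes above expose a few small helpers to Python: capability probes for grad-op makers and inplace inferers, a getter for FLAGS_use_mkldnn, and a LoDTensor._copy method. A minimal sketch of how they might be exercised from Python follows; the op name and tensor contents are illustrative only and not part of the patch.

    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid import core

    # Capability probes registered above in pybind.cc.
    print(core.has_grad_op_maker("relu"))   # True if the op registered a GradOpMaker
    print(core.has_infer_inplace("relu"))   # True if the op registered an inplace inferer
    print(core.get_flags_use_mkldnn())      # current value of FLAGS_use_mkldnn

    # LoDTensor._copy deep-copies a tensor to the given place, mirroring fetch_op.
    t = fluid.LoDTensor()
    t.set(np.ones((2, 3)).astype("float32"), fluid.CPUPlace())
    p = core.Place()
    p.set_place(fluid.CPUPlace())
    copied = t._copy(p)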
diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index b43e1922d5815c8e11da57d28efd8a0cde176eea..176221a0d43acca206448eb316a83249dbc67b1f 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -16,6 +16,7 @@ from __future__ import print_function
 
 import os
 import unittest
+import warnings
 import numpy as np
 import random
 import six
@@ -232,10 +233,12 @@ class OpTest(unittest.TestCase):
             inputs=inputs,
             outputs=outputs,
             attrs=self.attrs if hasattr(self, "attrs") else dict())
-        # infer variable type and infer shape in compile-time
+        # infer variable type and infer shape at compile time
         op.desc.infer_var_type(block.desc)
         op.desc.infer_shape(block.desc)
+        return op
+
 
     def _get_io_vars(self, block, numpy_inputs):
         inputs = {}
         for name, value in six.iteritems(numpy_inputs):
@@ -316,7 +319,13 @@ class OpTest(unittest.TestCase):
 
         return outputs
 
-    def _calc_output(self, place, parallel=False, no_check_set=None, loss=None):
+    def _calc_output(self,
+                     place,
+                     parallel=False,
+                     no_check_set=None,
+                     loss=None,
+                     enable_inplace=None,
+                     for_inplace_grad_test=None):
         program = Program()
         block = program.global_block()
         self._append_ops(block)
@@ -325,6 +334,14 @@ class OpTest(unittest.TestCase):
         outputs = self._get_outputs(block)
         feed_map = self.feed_var(inputs, place)
 
+        if for_inplace_grad_test is not None:
+            # Some variables hold no buffer (the tensor's holder_ is NULL), e.g. XShape in the reshape2 op,
+            # and their shapes contain 0 (e.g. XShape.shape = [0, 2, 5]). Mark them persistable so the
+            # inplace grad test can read them from the global scope directly rather than feeding them,
+            # since the feed op calls check_memory_size(), which fails when holder_ is NULL.
+            for name, var in block.vars.items():
+                if 0 in var.shape:
+                    var.persistable = True
         if parallel:
             use_cuda = False
             if isinstance(place, fluid.CUDAPlace):
@@ -351,6 +368,15 @@ class OpTest(unittest.TestCase):
         # fetch_list = map(block.var, fetch_list)
         if not isinstance(fetch_list[0], fluid.framework.Variable):
             fetch_list = list(map(block.var, fetch_list))
+
+        if enable_inplace is not None:
+            build_strategy = fluid.BuildStrategy()
+            build_strategy.enable_inplace = enable_inplace
+
+            compiled_prog = fluid.CompiledProgram(program).with_data_parallel(
+                build_strategy=build_strategy, places=place)
+            program = compiled_prog
+
         executor = Executor(place)
         outs = executor.run(program,
                             feed=feed_map,
@@ -358,12 +384,168 @@ class OpTest(unittest.TestCase):
                             return_numpy=False)
         return outs, fetch_list
 
+    def check_inplace_output_with_place(self,
+                                        place,
+                                        no_check_set=None,
+                                        inplace_atol=None):
+        # cannot enable inplace for this op
+        if not fluid.core.has_infer_inplace(self.op_type):
+            return
+        expect_outs, fetch_list = self._calc_output(
+            place, no_check_set=no_check_set, enable_inplace=False)
+        actual_outs, fetch_list = self._calc_output(
+            place, no_check_set=no_check_set, enable_inplace=True)
+
+        # compare expect_outs and actual_outs
+        for i, out in enumerate(fetch_list):
+            if inplace_atol is not None:
+                self.assertTrue(
+                    np.allclose(
+                        np.array(expect_outs[i]),
+                        np.array(actual_outs[i]),
+                        atol=inplace_atol),
+                    "Output (" + out.name + ") has diff at " + str(place) +
+                    " when using and not using inplace" + "\nExpect " +
+                    str(expect_outs[i]) + "\n" + "But Got" +
+                    str(actual_outs[i]) + " in class " +
+                    self.__class__.__name__)
+            else:
+                self.assertTrue(
+                    np.array_equal(
+                        np.array(expect_outs[i]), np.array(actual_outs[i])),
+                    "Output (" + out.name + ") has diff at " + str(place) +
+                    " when using and not using inplace" + "\nExpect " +
+                    str(expect_outs[i]) + "\n" + "But Got" +
+                    str(actual_outs[i]) + " in class " +
+                    self.__class__.__name__ + '\n')
+
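The consistency check above boils down to compiling the same program twice with BuildStrategy.enable_inplace toggled and comparing the fetched tensors. A condensed sketch of that pattern, using only APIs that appear in this patch (program, feed_map, and fetch_list construction omitted):

    import numpy as np
    import paddle.fluid as fluid

    def run_once(program, feed_map, fetch_list, place, enable_inplace):
        build_strategy = fluid.BuildStrategy()
        build_strategy.enable_inplace = enable_inplace
        compiled = fluid.CompiledProgram(program).with_data_parallel(
            build_strategy=build_strategy, places=place)
        exe = fluid.Executor(place)
        return exe.run(compiled, feed=feed_map, fetch_list=fetch_list,
                       return_numpy=False)

    # expect = run_once(prog, feed, fetch, place, enable_inplace=False)
    # actual = run_once(prog, feed, fetch, place, enable_inplace=True)
    # np.testing.assert_allclose(np.array(expect[0]), np.array(actual[0]), atol=1e-4)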
+    def check_inplace_grad_output_with_place(self,
+                                             place,
+                                             no_check_set=None,
+                                             inplace_atol=None):
+        # create a forward program to get the forward vars
+        program = Program()
+        block = program.global_block()
+        op = self._append_ops(block)
+        inputs = self._get_inputs(block)
+        outputs = self._get_outputs(block)
+        feed_map = self.feed_var(inputs, place)
+
+        # get the grad_op
+        if not fluid.core.has_grad_op_maker(op.desc.type()):
+            return
+        grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc(op.desc,
+                                                                  set(), [])
+        # has a grad_op_maker but generates no grad_op
+        if not grad_op_desc_list:
+            return
+
+        for i, grad_op_desc in enumerate(grad_op_desc_list):
+            # the grad_op cannot run inplace
+            if not fluid.core.has_infer_inplace(grad_op_desc.type()):
+                continue
+            # get forward outs
+            forward_outs, fetch_list = self._calc_output(
+                place, no_check_set=no_check_set, for_inplace_grad_test=True)
+
+            # create the grad program
+            grad_program = Program()
+            grad_block = grad_program.global_block()
+            new_op_desc = grad_block.desc.append_op()
+            new_op_desc.copy_from(grad_op_desc)
+            grad_program._sync_with_cpp()
+
+            # create grad vars based on the forward vars (shape and dtype)
+            for arg in grad_op_desc.input_arg_names(
+            ) + grad_op_desc.output_arg_names():
+                forward_var_name = op_grad_to_var.get(arg, None)
+                if forward_var_name is None:
+                    forward_var_name = arg
+                forward_var = block.vars.get(forward_var_name)
+                assert forward_var is not None, "{} cannot be found".format(
+                    forward_var_name)
+                grad_var = grad_block.create_var(
+                    name=arg,
+                    dtype=forward_var.dtype,
+                    shape=forward_var.shape,
+                    type=forward_var.type,
+                    persistable=False)
+                # Some variables hold no buffer (the tensor's holder_ is NULL), e.g. XShape in the reshape2 op,
+                # and their shapes contain 0 (e.g. XShape.shape = [0, 2, 5]). Mark them persistable so the
+                # inplace grad test can read them from the global scope directly rather than feeding them,
+                # since the feed op calls check_memory_size(), which fails when holder_ is NULL.
+                if 0 in grad_var.shape:
+                    grad_var.persistable = True
+            grad_program._sync_with_cpp()
+            grad_fetch_list = grad_op_desc.output_arg_names()
+
+            def _calc_grad_output(enable_inplace=None):
+                # Generate a feed_map for grad_program.
+                # Gradient accuracy is not checked here, only the consistency between running
+                # with and without inplace, so forward outputs (and sometimes inputs) are used
+                # as the (fake) gradient feeds.
+                p = core.Place()
+                p.set_place(place)
+                grad_feed_map = {}
+                for arg in grad_op_desc.input_arg_names():
+                    if arg in feed_map.keys():
+                        grad_feed_map[arg] = feed_map[arg]._copy(p)
+                    else:
+                        forward_var_name = op_grad_to_var.get(arg, None)
+                        if forward_var_name is None:
+                            forward_var_name = arg
+                        for i, out in enumerate(fetch_list):
+                            if out.name == forward_var_name:
+                                # Do not feed variables whose tensors hold no buffer (shape contains 0,
+                                # e.g. shape = [0, 2, 5], and holder_ is NULL), like XShape in the reshape2 op.
+                                # Read them from the global scope directly, since they were made persistable
+                                # in the forward execution.
+                                if 0 in out.shape:
+                                    continue
+                                else:
+                                    grad_feed_map[arg] = forward_outs[i]._copy(
+                                        p)
+
+                exe = Executor(place)
+                build_strategy = fluid.BuildStrategy()
+                build_strategy.enable_inplace = enable_inplace
+                compiled_program = fluid.CompiledProgram(
+                    grad_program).with_data_parallel(
+                        build_strategy=build_strategy, places=place)
+                outs = exe.run(compiled_program,
+                               feed=grad_feed_map,
+                               fetch_list=grad_fetch_list,
+                               return_numpy=False)
+                return outs
+
+            expect_outs = _calc_grad_output(enable_inplace=False)
+            actual_outs = _calc_grad_output(enable_inplace=True)
+
+            # compare expect_outs and actual_outs
+            for i, out_name in enumerate(grad_fetch_list):
+                if inplace_atol is not None:
+                    self.assertTrue(
+                        np.allclose(
+                            np.array(expect_outs[i]),
+                            np.array(actual_outs[i]),
+                            atol=inplace_atol),
+                        "Output (" + out_name + ") has diff at " + str(place) +
+                        " when using and not using inplace" + "\nExpect " +
+                        str(expect_outs[i]) + "\n" + "But Got" +
+                        str(actual_outs[i]) + " in class " +
+                        self.__class__.__name__)
+                else:
+                    self.assertTrue(
+                        np.array_equal(
+                            np.array(expect_outs[i]), np.array(actual_outs[i])),
+                        "Output (" + out_name + ") has diff at " + str(place) +
+                        " when using and not using inplace" + "\nExpect " +
+                        str(expect_outs[i]) + "\n" + "But Got" +
+                        str(actual_outs[i]) + " in class " +
+                        self.__class__.__name__)
+
     def check_output_with_place(self,
                                 place,
                                 atol,
                                 no_check_set=None,
                                 equal_nan=False,
-                                check_dygraph=False):
+                                check_dygraph=False,
+                                inplace_atol=None):
         if check_dygraph:
             dygraph_outs = self._calc_dygraph_output(
                 place, no_check_set=no_check_set)
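For reference, the grad-op lookup that drives the loop above can be exercised on its own. A small sketch under stated assumptions (the relu op and the variable shapes are illustrative; the block construction mirrors what _append_ops does):

    import paddle.fluid as fluid
    from paddle.fluid import core

    program = fluid.Program()
    block = program.global_block()
    x = block.create_var(name="X", shape=[2, 3], dtype="float32")
    out = block.create_var(name="Out", shape=[2, 3], dtype="float32")
    op = block.append_op(type="relu", inputs={"X": x}, outputs={"Out": out})

    if core.has_grad_op_maker(op.desc.type()):
        grad_op_descs, op_grad_to_var = core.get_grad_op_desc(op.desc, set(), [])
        for g in grad_op_descs:
            # e.g. relu_grad, and whether it registered an inplace inferer
            print(g.type() + " inplace: " + str(core.has_infer_inplace(g.type())))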
@@ -464,6 +646,25 @@
                     "Output (" + out_name + ") has different lod at " +
                     str(place) + " in dygraph mode")
 
+        # inplace_atol is only used when the op does not ensure computational consistency
+        if inplace_atol is not None:
+            warnings.warn(
+                "By default, inplace_atol should not be set, please check it")
+        self.check_inplace_output_with_place(
+            place, no_check_set=no_check_set, inplace_atol=inplace_atol)
+
+        # TODO(zhiqiu): enhance the inplace_grad test for ops (sum and activation) using mkldnn
+        # skip use_mkldnn currently
+        flags_use_mkldnn = fluid.core.get_flags_use_mkldnn()
+        attrs_use_mkldnn = hasattr(
+            self, 'attrs') and bool(self.attrs.get('use_mkldnn', False))
+        if flags_use_mkldnn or attrs_use_mkldnn:
+            warnings.warn(
+                "check inplace_grad for ops using mkldnn is not supported")
+            return
+        self.check_inplace_grad_output_with_place(
+            place, no_check_set=no_check_set, inplace_atol=inplace_atol)
+
     def _get_places(self):
         if self.dtype == np.float16:
             if core.is_compiled_with_cuda() and core.op_support_gpu(
@@ -489,7 +690,8 @@
                      atol=1e-5,
                      no_check_set=None,
                      equal_nan=False,
-                     check_dygraph=False):
+                     check_dygraph=False,
+                     inplace_atol=None):
         places = self._get_places()
         for place in places:
             self.check_output_with_place(place, atol, no_check_set, equal_nan,
diff --git a/python/paddle/fluid/tests/unittests/test_group_norm_op.py b/python/paddle/fluid/tests/unittests/test_group_norm_op.py
index 0b6d039f050898793b69312f50f6709d66d080cd..386c3b1f0e438dc50943009f0fe8663838a32ecc 100644
--- a/python/paddle/fluid/tests/unittests/test_group_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_group_norm_op.py
@@ -61,11 +61,15 @@ class TestGroupNormOp(OpTest):
 
     def test_check_output(self):
         atol = 1e-4
+        inplace_atol = 1e-4
         place = core.CPUPlace()
-        self.check_output_with_place(place, atol=atol)
+        # pass inplace_atol because group_norm does not ensure computational consistency
+        self.check_output_with_place(
+            place, atol=atol, inplace_atol=inplace_atol)
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(place, atol=atol)
+            self.check_output_with_place(
+                place, atol=atol, inplace_atol=inplace_atol)
 
     def do_compare_between_place(self):
         if not core.is_compiled_with_cuda(): return
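With this patch, every OpTest-based unit test picks up the inplace and inplace-grad consistency checks automatically through check_output(). A hypothetical minimal test sketch (op choice, data, and class name are illustrative; run from the unittests directory so op_test is importable):

    import unittest
    import numpy as np
    from op_test import OpTest

    class TestReluInplaceConsistency(OpTest):  # hypothetical test class
        def setUp(self):
            self.op_type = "relu"
            x = np.random.uniform(-1, 1, (2, 3)).astype("float32")
            self.inputs = {'X': x}
            self.outputs = {'Out': np.maximum(x, 0)}

        def test_check_output(self):
            # The inplace checks run automatically; pass inplace_atol only when the op
            # is not bitwise reproducible, as test_group_norm_op.py does above.
            self.check_output()

    if __name__ == '__main__':
        unittest.main()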