diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc
index 1c20cf9cc200fc1d5e65076cec78d5e1b143831c..6629a203f80ede3883860a630861f27e7edbe977 100644
--- a/paddle/fluid/operators/conv_cudnn_op.cu.cc
+++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc
@@ -509,7 +509,8 @@ REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace,
                    paddle::operators::CUDNNConvOpKernel<plat::float16>);
 REGISTER_OP_KERNEL(conv3d_grad, CUDNN, plat::CUDAPlace,
                    paddle::operators::CUDNNConvGradOpKernel<float>,
-                   paddle::operators::CUDNNConvGradOpKernel<double>);
+                   paddle::operators::CUDNNConvGradOpKernel<double>,
+                   paddle::operators::CUDNNConvGradOpKernel<plat::float16>);
 REGISTER_OP_KERNEL(
     conv3d_grad_grad, CUDNN, plat::CUDAPlace,
     paddle::operators::CUDNNConvDoubleGradOpKernel<float>,
diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h
index aa8bfdf9d1689b4b47224d872b2e8eebd35fa9eb..581caad62ed5d382af8957631ff8dbdbc401b1cb 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h
@@ -186,8 +186,7 @@ class ElementwiseMulDoubleGradKernel : public framework::OpKernel<T> {
   }
 };
 
-DECLARE_INPLACE_OP_INFERER(ElementwiseMulDoubleGradOpInplace, {"DDX", "DDOut"},
-                           {"X", framework::GradVarName("X")},
-                           {"Y", framework::GradVarName("Y")});
+DECLARE_INPLACE_OP_INFERER(ElementwiseMulDoubleGradOpInplace, {"DDX", "DDOut"});
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index 19580fc804987f47d8a5498fe2aaa5b37101d13c..aed0008350be7ce4e93e75ee1a5aeb5f75e71175 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -318,33 +318,6 @@ class OpTest(unittest.TestCase):
                 attrs=self.attrs)
         return outputs
 
-    def _compare_expect_and_actual_outputs(self,
-                                           place,
-                                           fetch_list,
-                                           expect_outs,
-                                           actual_outs,
-                                           inplace_atol=None):
-        # compare expect_outs and actual_outs
-        for i, name in enumerate(fetch_list):
-            if inplace_atol is not None:
-                self.assertTrue(
-                    np.allclose(
-                        np.array(expect_outs[i]),
-                        np.array(actual_outs[i]),
-                        atol=inplace_atol),
-                    "Output (" + name + ") has diff at " + str(place) +
-                    " when using and not using inplace" + "\nExpect " +
-                    str(expect_outs[i]) + "\n" + "But Got" + str(actual_outs[i])
-                    + " in class " + self.__class__.__name__)
-            else:
-                self.assertTrue(
-                    np.array_equal(
-                        np.array(expect_outs[i]), np.array(actual_outs[i])),
-                    "Output (" + name + ") has diff at " + str(place) +
-                    " when using and not using inplace" + "\nExpect " +
-                    str(expect_outs[i]) + "\n" + "But Got" + str(actual_outs[i])
-                    + " in class " + self.__class__.__name__ + '\n')
-
     def _calc_output(self,
                      place,
                      parallel=False,
@@ -365,7 +338,8 @@ class OpTest(unittest.TestCase):
             # and the shapes of those variables contain 0 (eg. Xshape.shape = [0, 2, 5]).
             # Set persistable for those variables in order to get them from global_scope for inplace grad test directly other than feed them,
             # since feed op calls check_memory_size() which fails when tensor's holder_ is NULL.
-            for name, var in block.vars.items():
+            for out_name in op.output_arg_names:
+                var = block.var(out_name)
                 if 0 in var.shape:
                     var.persistable = True
             original_program = program
@@ -411,13 +385,189 @@ class OpTest(unittest.TestCase):
         else:
             return outs, fetch_list
 
-    def check_inplace_output_with_place(self,
-                                        place,
-                                        no_check_set=None,
-                                        inplace_atol=None):
-        # can`t enable inplace
-        if not fluid.core.has_infer_inplace(self.op_type):
-            return
+    def _compare_expect_and_actual_outputs(self,
+                                           place,
+                                           fetch_list,
+                                           expect_outs,
+                                           actual_outs,
+                                           inplace_atol=None):
+        """Compare the expected outs and actual outs of the tested op.
+
+        Args:
+            place (CPUPlace | CUDAPlace): The place where the op runs.
+            fetch_list (list): The outputs of the tested op.
+            expect_outs (list): The expected outs of the tested op.
+            actual_outs (list): The actual outs of the tested op.
+            inplace_atol (float): The tolerable error, only set when the tested op doesn't ensure computational consistency, like group_norm op.
+
+        Returns:
+            None.
+        """
+        # compare expect_outs and actual_outs
+        for i, name in enumerate(fetch_list):
+            if inplace_atol is not None:
+                self.assertTrue(
+                    np.allclose(
+                        np.array(expect_outs[i]),
+                        np.array(actual_outs[i]),
+                        atol=inplace_atol),
+                    "Output (" + name + ") has diff at " + str(place) +
+                    " when using and not using inplace" + "\nExpect " +
+                    str(expect_outs[i]) + "\n" + "But Got" + str(actual_outs[i])
+                    + " in class " + self.__class__.__name__)
+            else:
+                self.assertTrue(
+                    np.array_equal(
+                        np.array(expect_outs[i]), np.array(actual_outs[i])),
+                    "Output (" + name + ") has diff at " + str(place) +
+                    " when using and not using inplace" + "\nExpect " +
+                    str(expect_outs[i]) + "\n" + "But Got" + str(actual_outs[i])
+                    + " in class " + self.__class__.__name__ + '\n')
+
+    def _construct_grad_program_from_forward(self, fwd_program, grad_op_desc,
+                                             op_grad_to_var):
+        """Generate grad_program which contains the grad_op.
+
+        Args:
+            fwd_program (Program): The program that contains grad_op_desc's corresponding forward op.
+            grad_op_desc (OpDesc): The OpDesc of the grad op.
+            op_grad_to_var (dict): The relation of variables in the grad op and its forward op.
+
+        Returns:
+            grad_program (Program): The program which contains the grad_op.
+        """
+        grad_program = Program()
+        grad_block = grad_program.global_block()
+        new_op_desc = grad_block.desc.append_op()
+        new_op_desc.copy_from(grad_op_desc)
+        grad_program._sync_with_cpp()
+
+        # Create grad vars based on fwd vars (shape and dtype)
+        for arg in grad_op_desc.input_arg_names(
+        ) + grad_op_desc.output_arg_names():
+            fwd_var_name = op_grad_to_var.get(arg, None)
+            if fwd_var_name is None:
+                fwd_var_name = arg
+            fwd_var = fwd_program.global_block().vars.get(fwd_var_name)
+            assert fwd_var is not None, "{} cannot be found".format(
+                fwd_var_name)
+            grad_var = grad_block.create_var(
+                name=arg,
+                dtype=fwd_var.dtype,
+                shape=fwd_var.shape,
+                type=fwd_var.type,
+                persistable=False)
+
+            # Some variables' tensors hold no buffer (tensor's _holder is NULL), like XShape in reshape2 op,
+            # and the shapes of those variables contain 0 (eg. Xshape.shape = [0, 2, 5]).
+            # Set persistable for those variables in order to get them from global_scope for inplace grad test directly other than feed them,
+            # since feed op calls check_memory_size() which fails when tensor's holder_ is NULL.
+            if 0 in grad_var.shape:
+                grad_var.persistable = True
+        grad_program._sync_with_cpp()
+        return grad_program
+
+    def _construct_grad_feed_map_from_forward(self, place, fwd_res,
+                                              grad_op_desc, op_grad_to_var):
+        """Generate grad_feed_map for grad_program.
+
+        Since we don't really check gradient accuracy, but only the consistency when using and not using inplace,
+        we use fwd outs (and sometimes fwd inputs) to construct grad inputs.
+
+        Args:
+            place (CPUPlace | CUDAPlace): The place where the op runs.
+            fwd_res (tuple): The outputs of its forward op, in the same form as returns of _calc_output() when for_inplace_test is True.
+                i.e., tuple(fwd_outs, fwd_fetch_list, fwd_feed_map, fwd_program, fwd_op_desc)
+            grad_op_desc (OpDesc): The OpDesc of the grad op.
+            op_grad_to_var (dict): The relation of variables in the grad op and its fwd_op.
+
+        Returns:
+            grad_feed_map (dict): The feed_map of the grad op.
+        """
+        fwd_outs, fwd_fetch_list, fwd_feed_map, fwd_program, fwd_op_desc = fwd_res
+        p = core.Place()
+        p.set_place(place)
+        grad_feed_map = {}
+        for arg in grad_op_desc.input_arg_names():
+            if arg in fwd_feed_map.keys():
+                grad_feed_map[arg] = fwd_feed_map[arg]._copy(p)
+            else:
+                fwd_var_name = op_grad_to_var.get(arg, None)
+                if fwd_var_name is None:
+                    fwd_var_name = arg
+
+                for i, out_name in enumerate(fwd_fetch_list):
+                    if out_name == fwd_var_name:
+                        # don't feed variables whose tensors hold no buffer (shape contains 0 like shape = [0,2,5] and holder_ is NULL), like XShape in reshape2 op.
+                        # get them from global_scope directly since we have set them persistable in fwd execution
+                        if 0 in fwd_program.global_block().var(out_name).shape:
+                            continue
+                        else:
+                            grad_feed_map[arg] = fwd_outs[i]._copy(p)
+        return grad_feed_map
+
+    def _get_need_run_ops(self, op_desc, fwd_op_desc=None):
+        """Postorder traversal of the 'grad' tree to get all ops that need to run during inplace test.
+        An op needs to run during the inplace check if,
+        (1) it has infer_inplace, or
+        (2) it has infer_inplace in its grad descendants (since we need its outputs to construct its grad's inputs).
+
+        Args:
+            op_desc (OpDesc): The op_desc of the current op.
+            fwd_op_desc (OpDesc): The op_desc of the current op's forward op, None if the current op has no forward op.
+                E.g. relu's fwd_op is None, relu_grad's fwd_op is relu, relu_grad_grad's fwd_op is relu_grad, etc.
+
+        Returns:
+            need_run_ops (list[(op_desc, fwd_op_desc)]): The ops that need to run during inplace test.
+ """ + need_run_ops = [] + visited_ops = [] + + def _dfs_grad_op(op_desc, fwd_op_desc=None): + visited_ops.append(op_desc.type()) + has_infer_inplace = fluid.core.has_infer_inplace(op_desc.type()) + has_grad_op_maker = fluid.core.has_grad_op_maker(op_desc.type()) + has_infer_inplace_in_grad_descendants = False + if not has_grad_op_maker: + has_infer_inplace_in_descendants = False + else: + # get grad_op_desc + grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc( + op_desc, set(), []) + if not grad_op_desc_list: + has_infer_inplace_in_grad_descendants = False + else: + for i, grad_op_desc in enumerate(grad_op_desc_list): + if grad_op_desc.type( + ) not in visited_ops and _dfs_grad_op( + grad_op_desc, fwd_op_desc=op_desc): + has_infer_inplace_in_grad_descendants = True + if has_infer_inplace or has_infer_inplace_in_grad_descendants: + need_run_ops.append((op_desc, fwd_op_desc)) + return True + else: + return False + + _dfs_grad_op(op_desc, fwd_op_desc=fwd_op_desc) + return need_run_ops + + def _check_forward_inplace(self, + place, + no_check_set=None, + inplace_atol=None): + """Chech the inplace correctness of given op (self.op_type). + Run the op twice with same inputs, one enable inplace and another disable, compare their outputs. + + Args: + place (CPUPlace | CUDAPlace): The place where the op runs. + no_check_set (list): The names of outputs that needn't check, like XShape of reshape op. + inplace_atol (float): The tolerable error, only set when op doesn't ensure computational consistency, like group_norm op. + + Returns: + expect_res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given op. + We return this to construct grad_program and grad_feed_map for grad inplace check. + """ + # _calc_output() returns in the form tuple(outs, fetch_list, feed_map, program, op_desc) when for_inplace_test=True. 
         expect_res = self._calc_output(
             place,
             no_check_set=no_check_set,
@@ -428,7 +578,6 @@ class OpTest(unittest.TestCase):
             no_check_set=no_check_set,
             enable_inplace=True,
             for_inplace_test=True)
-
         # compare expect_outs and actual_outs
         self._compare_expect_and_actual_outputs(
             place,
@@ -436,160 +585,149 @@ class OpTest(unittest.TestCase):
             expect_res[0],
             actual_res[0],
             inplace_atol=inplace_atol)
+        return expect_res
 
-        # check grad
-        # TODO(zhiqiu): enhance inplace_grad test for ops (sum and activation) using mkldnn
-        # skip use_mkldnn currently
-        flags_use_mkldnn = fluid.core.get_flags_use_mkldnn()
-        attrs_use_mkldnn = hasattr(
-            self, 'attrs') and bool(self.attrs.get('use_mkldnn', False))
-        if flags_use_mkldnn or attrs_use_mkldnn:
-            warnings.warn(
-                "check inplace_grad for ops using mkldnn is not supported")
-            return
-        use_ngraph = fluid.core.is_compiled_with_ngraph(
-        ) and fluid.core.get_flags_use_ngraph()
-        if use_ngraph:
-            warnings.warn(
-                "check inplace_grad for ops using ngraph is not supported")
-            return
-
-        fwd_outs = expect_res[0]
-        fwd_fetch_list = expect_res[1]
-        fwd_feed_map = expect_res[2]
-        fwd_program = expect_res[3]
-        fwd_op_desc = expect_res[4]
-        self.check_inplace_grad_output_using_fwd_inputs_outputs(
-            place,
-            fwd_feed_map,
-            fwd_fetch_list,
-            fwd_outs,
-            fwd_program,
-            fwd_op_desc,
-            no_check_set=no_check_set,
-            inplace_atol=inplace_atol,
-            depth=0)
-
-    def check_inplace_grad_output_using_fwd_inputs_outputs(self,
-                                                           place,
-                                                           fwd_feed_map,
-                                                           fwd_fetch_list,
-                                                           fwd_outs,
-                                                           fwd_program,
-                                                           fwd_op_desc,
-                                                           no_check_set=None,
-                                                           inplace_atol=None,
-                                                           depth=0):
-        # depth=0 means grad
-        # depth=1 means double_grad
-        # depth=2 means triple_grad, which is not supported yet
-        if depth >= 2:
-            return
-        # get grad_op
-        if not fluid.core.has_grad_op_maker(fwd_op_desc.type()):
-            return
+    def _calc_grad_output(self,
+                          place,
+                          fwd_res,
+                          grad_op_desc,
+                          enable_inplace=None):
+        """Calculate grad_output for the given grad_op_desc.
+
+        Since we don't really check gradient accuracy, but only the consistency when using and not using inplace,
+        we use fwd outs (and sometimes fwd inputs) to construct grad inputs.
+
+        Args:
+            place (CPUPlace | CUDAPlace): The place where the op runs.
+            fwd_res (tuple): The outputs of its forward op, in the same form as returns of _calc_output() when for_inplace_test is True.
+                i.e., tuple(fwd_outs, fwd_fetch_list, fwd_feed_map, fwd_program, fwd_op_desc).
+            grad_op_desc (OpDesc): The OpDesc of the grad op.
+            enable_inplace (bool): Enable inplace or not.
+
+        Returns:
+            res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of the given grad_op_desc.
+ """ + fwd_outs, fwd_fetch_list, fwd_feed_map, fwd_program, fwd_op_desc = fwd_res grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc(fwd_op_desc, set(), []) - # has grad_op_maker but no grad_op - if not grad_op_desc_list: - return - for i, grad_op_desc in enumerate(grad_op_desc_list): - # grad_op can not inplace - if not fluid.core.has_infer_inplace(grad_op_desc.type()): - continue + grad_program = self._construct_grad_program_from_forward( + fwd_program, grad_op_desc, op_grad_to_var) + grad_feed_map = self._construct_grad_feed_map_from_forward( + place, fwd_res, grad_op_desc, op_grad_to_var) + grad_fetch_list = grad_op_desc.output_arg_names() + exe = Executor(place) + program = grad_program + if enable_inplace is not None: + build_strategy = fluid.BuildStrategy() + build_strategy.enable_inplace = enable_inplace + compiled_program = fluid.CompiledProgram( + grad_program).with_data_parallel( + loss_name="", build_strategy=build_strategy, places=place) + program = compiled_program + outs = exe.run(program, + feed=grad_feed_map, + fetch_list=grad_fetch_list, + return_numpy=False) + return outs, grad_fetch_list, grad_feed_map, grad_program, grad_op_desc + + def _check_grad_inplace(self, + place, + fwd_res, + grad_op_desc, + inplace_atol=None): + """Chech the inplace correctness of given grad_op_desc. + + Run the grad op twice with same inputs, one enable inplace and another disable, compare their outputs. + It works like _check_forward_inplace, but the way to construct program and feed_map differs. + So we define a new function for grad, grad_grad, etc. - # create grad program - grad_program = Program() - grad_block = grad_program.global_block() - new_op_desc = grad_block.desc.append_op() - new_op_desc.copy_from(grad_op_desc) - grad_program._sync_with_cpp() + Args: + place (CPUPlace | CUDAPlace): The place where the op runs. + fwd_res (tuple): The outputs of its forward op, in the same form as returns of _calc_outputs() when for_inplace_test is True. + i.e., tuple(fwd_outs, fwd_fetch_list, fwd_feed_map, fwd_program, fwd_op_desc). + grad_op_desc (OpDesc): The OpDesc of grad op. + inplace_atol (float): The tolerable error, only set when op doesn't ensure computational consistency, like group_norm op. - # create grad vars based on fwd vars (shape and dtype) - for arg in grad_op_desc.input_arg_names( - ) + grad_op_desc.output_arg_names(): - fwd_var_name = op_grad_to_var.get(arg, None) - if fwd_var_name is None: - fwd_var_name = arg - fwd_var = fwd_program.global_block().vars.get(fwd_var_name) - assert fwd_var is not None, "{} cannot be found".format( - fwd_var_name) - grad_var = grad_block.create_var( - name=arg, - dtype=fwd_var.dtype, - shape=fwd_var.shape, - type=fwd_var.type, - persistable=False) - # some variables' tensors hold no buffer (tensor's _holder is NULL), like XShape in reshape2 op, - # and the shapes of those variables contain 0 (eg. Xshape.shape = [0, 2, 5]). - # set persistable for those variables in order to get them from global_scope for inplace grad test directly other than feed them, - # since feed op calls check_memory_size() which fails when tensor's holder_ is NULL. 
-                if 0 in grad_var.shape:
-                    grad_var.persistable = True
-            grad_program._sync_with_cpp()
-            grad_fetch_list = grad_op_desc.output_arg_names()
-
-            # generate grad_feed_map for grad_program
-            # since we don`t really check gradient accuracy, but the consistency when using and not using inplace
-            # we use fwd outs (also inputs sometimes) as grad (fake) feeds
-            p = core.Place()
-            p.set_place(place)
-            grad_feed_map = {}
-            for arg in grad_op_desc.input_arg_names():
-                if arg in fwd_feed_map.keys():
-                    grad_feed_map[arg] = fwd_feed_map[arg]._copy(p)
-                else:
-                    fwd_var_name = op_grad_to_var.get(arg, None)
-                    if fwd_var_name is None:
-                        fwd_var_name = arg
-
-                    for i, out_name in enumerate(fwd_fetch_list):
-                        if out_name == fwd_var_name:
-                            # don't feed variables whose tensors hold no buffer (shape contains 0 like shape = [0,2,5] and holder_ is NULL), like XShape in reshape2 op.
-                            # get them from global_scope directly since we have set them persistable in fwd execution
-                            if 0 in fwd_program.global_block().var(
-                                    out_name).shape:
-                                continue
-                            else:
-                                grad_feed_map[arg] = fwd_outs[i]._copy(p)
-
-            def _calc_grad_output(enable_inplace=None):
-                exe = Executor(place)
-                build_strategy = fluid.BuildStrategy()
-                build_strategy.enable_inplace = enable_inplace
-                compiled_program = fluid.CompiledProgram(
-                    grad_program).with_data_parallel(
-                        loss_name="",
-                        build_strategy=build_strategy,
-                        places=place)
-                outs = exe.run(compiled_program,
-                               feed=grad_feed_map,
-                               fetch_list=grad_fetch_list,
-                               return_numpy=False)
-                return outs
-
-            expect_outs = _calc_grad_output(enable_inplace=False)
-            actual_outs = _calc_grad_output(enable_inplace=True)
-
-            # compare expect_outs and actual_outs
-            self._compare_expect_and_actual_outputs(
-                place,
-                grad_fetch_list,
-                expect_outs,
-                actual_outs,
-                inplace_atol=inplace_atol)
+
+        Returns:
+            expect_res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of the given op.
+                We return this to construct grad_program and grad_feed_map for the grad inplace check.
+        """
+        expect_res = self._calc_grad_output(
+            place, fwd_res, grad_op_desc, enable_inplace=False)
+        actual_res = self._calc_grad_output(
+            place, fwd_res, grad_op_desc, enable_inplace=True)
+        self._compare_expect_and_actual_outputs(
+            place,
+            expect_res[1],
+            expect_res[0],
+            actual_res[0],
+            inplace_atol=inplace_atol)
+        return expect_res
-
-            # check grad of grad, recursively
-            self.check_inplace_grad_output_using_fwd_inputs_outputs(
-                place,
-                grad_feed_map,
-                grad_fetch_list,
-                expect_outs,
-                grad_program,
-                grad_op_desc,
-                no_check_set=no_check_set,
-                inplace_atol=inplace_atol,
-                depth=depth + 1)
 
+    def check_inplace_output_with_place(self,
+                                        place,
+                                        no_check_set=None,
+                                        inplace_atol=None):
+        """Check the inplace correctness of the given op, its grad op, its grad_grad op, etc.
+
+        (1) Get all ops that need to run. (see conditions in _get_need_run_ops())
+        (2) Run the ops in need_run_ops, and do the inplace check for each op that has infer_inplace.
+
+        Args:
+            place (CPUPlace | CUDAPlace): The place where the op runs.
+            no_check_set (list): The names of outputs that need not be checked, like XShape of reshape op.
+            inplace_atol (float): The tolerable error, only set when the op doesn't ensure computational consistency, like group_norm op.
+
+        Returns:
+            None
+        """
+        has_infer_inplace = fluid.core.has_infer_inplace(self.op_type)
+        has_grad_op_maker = fluid.core.has_grad_op_maker(self.op_type)
+
+        fwd_res = self._calc_output(
+            place, no_check_set=no_check_set, for_inplace_test=True)
+        op_desc = fwd_res[4]
+        need_run_ops = self._get_need_run_ops(op_desc)
+
+        res = {}
+        for op_desc, father_op_desc in reversed(need_run_ops):
+            # The first one is the forward op
+            has_infer_inplace = fluid.core.has_infer_inplace(op_desc.type())
+            if op_desc.type() == self.op_type:
+                if has_infer_inplace:
+                    res[op_desc] = self._check_forward_inplace(
+                        place,
+                        no_check_set=no_check_set,
+                        inplace_atol=inplace_atol)
+                else:
+                    res[op_desc] = self._calc_output(
+                        place, no_check_set=no_check_set, for_inplace_test=True)
+            else:
+                # TODO(zhiqiu): enhance inplace_grad test for ops (sum and activation) using mkldnn/ngraph
+                # skip ops that use mkldnn or ngraph currently
+                flags_use_mkldnn = fluid.core.get_flags_use_mkldnn()
+                attrs_use_mkldnn = hasattr(
+                    self,
+                    'attrs') and bool(self.attrs.get('use_mkldnn', False))
+                if flags_use_mkldnn or attrs_use_mkldnn:
+                    warnings.warn(
+                        "check inplace_grad for ops using mkldnn is not supported"
+                    )
+                    continue
+                use_ngraph = fluid.core.is_compiled_with_ngraph(
+                ) and fluid.core.get_flags_use_ngraph()
+                if use_ngraph:
+                    warnings.warn(
+                        "check inplace_grad for ops using ngraph is not supported"
+                    )
+                    continue
+                if has_infer_inplace:
+                    fwd_res = res[father_op_desc]
+                    res[op_desc] = self._check_grad_inplace(
+                        place, fwd_res, op_desc, inplace_atol=inplace_atol)
+                else:
+                    res[op_desc] = self._calc_grad_output(place, fwd_res,
+                                                          op_desc)
 
     def check_output_with_place(self,
                                 place,
@@ -701,6 +839,8 @@ class OpTest(unittest.TestCase):
         if inplace_atol is not None:
             warnings.warn(
                 "By default, inplace_atol should not be set, please check it")
+        # Check inplace for the given op, its grad op, its grad_grad op, etc.
+        # No effect on the original OpTest checks.
         self.check_inplace_output_with_place(
             place, no_check_set=no_check_set, inplace_atol=inplace_atol)
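
Note: the refactored inplace check is driven entirely from check_output_with_place(), so existing operator tests pick it up without modification. The sketch below is a minimal, hypothetical OpTest subclass that would exercise the new forward and grad inplace checks; the chosen op (scale), shapes, and attribute values are illustrative assumptions only and are not part of this patch.

import unittest
import numpy as np
from op_test import OpTest


class TestScaleInplaceExample(OpTest):
    # Hypothetical example: any OpTest subclass reaches the new
    # check_inplace_output_with_place() through check_output(), which
    # traverses the op, its grad op, and its grad_grad op as needed.
    def setUp(self):
        self.op_type = "scale"
        x = np.random.random((10, 10)).astype("float64")
        self.inputs = {'X': x}
        self.attrs = {'scale': 2.0}
        self.outputs = {'Out': x * 2.0}

    def test_check_output(self):
        # Runs the op with inplace disabled and enabled and compares the
        # fetched outputs; grad ops are collected via _get_need_run_ops().
        self.check_output()


if __name__ == '__main__':
    unittest.main()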