diff --git a/mindspore/ops/_grad/grad_quant_ops.py b/mindspore/ops/_grad/grad_quant_ops.py index 7ab9192040977cc172b2306364f2731abe64b23d..c1a272712dfdfc6a0c17bb5ba28e532c0bc248e3 100644 --- a/mindspore/ops/_grad/grad_quant_ops.py +++ b/mindspore/ops/_grad/grad_quant_ops.py @@ -18,6 +18,7 @@ from .. import operations as P from .grad_base import bprop_getters from ..composite.multitype_ops.zeros_like_impl import zeros_like +from ... import context @bprop_getters.register(P.FakeQuantPerLayer) @@ -64,12 +65,21 @@ def get_bprop_batchnorm_fold(self): @bprop_getters.register(P.CorrectionMul) def get_bprop_correction_mul(self): """Generate bprop for CorrectionMul for Ascend and GPU""" - grad = P.CorrectionMulGrad(self.channel_axis) + grad_dx = P.CorrectionMulGrad(self.channel_axis) + grad_d_batch_std = P.CorrectionMulGradReduce(self.channel_axis) def bprop(x, batch_std, running_std, out, dout): - dx, d_batch_std = grad(dout, x, batch_std, running_std) + dx, d_batch_std = grad_dx(dout, x, batch_std, running_std) return dx, d_batch_std, zeros_like(running_std) + def bprop_npu(x, batch_std, running_std, out, dout): + dx, mul_dx = grad_dx(dout, x, batch_std, running_std) + d_batch_std = grad_d_batch_std(mul_dx) + return dx, d_batch_std, zeros_like(running_std) + + if context.get_context('device_target') == "Ascend": + return bprop_npu + return bprop diff --git a/mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py b/mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py index e79217f5213db8621cf68f343cf96f2f7055369d..da3a634454add6c0505406bc85722f3d5be64d0e 100644 --- a/mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +++ b/mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py @@ -37,7 +37,7 @@ correction_mul_grad_op_info = TBERegOp("CorrectionMulGrad") \ .input(2, "batch_std", None, "required", None) \ .input(3, "running_std", None, "required", None) \ .output(0, "dx", True, "required", "all") \ - .output(1, "d_batch_std", True, "required", "all") \ + .output(1, "mul_dx", True, "required", "all") \ .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ .get_op_info() @@ -56,21 +56,14 @@ def correction_mul_grad_compute(dout, x, batch_std, running_std, channel, data_f factor = te.lang.cce.vdiv(batch_std, running_std) factor_b = te.lang.cce.broadcast(factor, shape_x) dx = te.lang.cce.vmul(dout, factor_b) - mul_data = te.lang.cce.vmul(dout, x) - if channel == 0: - if data_format == "NCHW": - axis = [1, 2, 3] - else: - axis = [1, 2, 3, 4] - else: - axis = [2, 3] - red_data = te.lang.cce.sum(mul_data, axis, keepdims=True) - d_batch_std = te.lang.cce.vdiv(red_data, running_std) - return [dx, d_batch_std] + mul_dx = te.lang.cce.vmul(dout, x) + running_std_b = te.lang.cce.broadcast(running_std, shape_x) + mul_dx = te.lang.cce.vdiv(mul_dx, running_std_b) + return [dx, mul_dx] @util.check_input_type(dict, dict, dict, dict, dict, dict, int, str) -def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channel, kernel_name="correction_mul_grad"): +def correction_mul_grad(dout, x, batch_std, running_std, dx, mul_dx, channel, kernel_name="correction_mul_grad"): """CorrectionMulGrad op""" shape_dout = dout.get("shape") shape_x = dout.get("shape") @@ -93,7 +86,7 @@ def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channe util.compare_tensor_dict_key(dout, x, "shape") util.compare_tensor_dict_key(dx, x, "shape") util.compare_tensor_dict_key(batch_std, running_std, "shape") - util.compare_tensor_dict_key(batch_std, d_batch_std, "shape") + util.compare_tensor_dict_key(dx, mul_dx, "shape") util.check_kernel_name(kernel_name) util.check_shape_rule(shape_x) @@ -120,7 +113,84 @@ def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channe with tvm.target.cce(): sch = generic.auto_schedule(res_list) - tensor_list = [dout_t, x_t, batch_std_t, running_std_t] + list(res_list) + tensor_list = [dout_t, x_t, batch_std_t, running_std_t] + res_list + config = {"print_ir": False, + "name": kernel_name, + "tensor_list": tensor_list} + + te.lang.cce.cce_build_code(sch, config) + + +correction_mul_grad_reduce_op_info = TBERegOp("CorrectionMulGradReduce") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("correction_mul_grad_reduce.so") \ + .compute_cost(10) \ + .kernel_name("correction_mul_grad_reduce") \ + .partial_flag(True) \ + .op_pattern("formatAgnostic") \ + .attr("channel_axis", "optional", "int", "all") \ + .input(0, "dout", None, "required", None) \ + .output(0, "d_batch_std", True, "required", "all") \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ + .get_op_info() + + +@op_info_register(correction_mul_grad_reduce_op_info) +def _correction_mul_grad_reduce_tbe(): + """CorrectionMulGradReduce TBE register""" + return + + +@fusion_manager.register("correction_mul_grad_reduce") +def correction_mul_grad_reduce_compute(mul_dx, channel, data_format, kernel_name="correction_mul"): + """CorrectionMulGradReduce compute""" + if channel == 0: + if data_format == "NCHW": + axis = [1, 2, 3] + else: + axis = [1, 2, 3, 4] + else: + axis = [2, 3] + d_batch_std = te.lang.cce.sum(mul_dx, axis, keepdims=True) + return d_batch_std + + +@util.check_input_type(dict, dict, int, str) +def correction_mul_grad_reduce(mul_dx, d_batch_std, channel, kernel_name="correction_mul_grad_reduce"): + """CorrectionMulGradReduce op""" + shape_dout = mul_dx.get("shape") + shape_x = mul_dx.get("shape") + + dtype_dout = mul_dx.get("dtype") + + inp_dtype_dout = dtype_dout.lower() + + util.check_dtype_rule(inp_dtype_dout, ("float16", "float32")) + + util.check_kernel_name(kernel_name) + util.check_shape_rule(shape_x) + util.check_shape_size(shape_x, SHAPE_SIZE_LIMIT) + + data_format = mul_dx.get("format") + ori_format = mul_dx.get("format") + if data_format.upper() not in ("NC1HWC0", "NCHW"): + raise RuntimeError("Un supported data format {}".format(data_format)) + if data_format.upper() == "NCHW" and ori_format != "NCHW": + raise RuntimeError("data_format(NCHW) must same as ori_format") + + shape_c = [1] * len(shape_x) + shape_c[channel] = d_batch_std.get("ori_shape")[0] + if data_format == "NC1HWC0" and channel == 1: + shape_c = d_batch_std.get("shape") + + dout_t = tvm.placeholder(shape_dout, name="dout", dtype=inp_dtype_dout) + res = correction_mul_grad_reduce_compute(dout_t, channel, data_format, kernel_name) + + with tvm.target.cce(): + sch = generic.auto_schedule(res) + + tensor_list = [dout_t, res] config = {"print_ir": False, "name": kernel_name, "tensor_list": tensor_list} diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py index a6abb45b7d8733ae2c67d2dfa50049d1fa10a442..05dcf53ef33d0f5ad8566abcbd5c61bbf53664d9 100644 --- a/mindspore/ops/operations/_quant_ops.py +++ b/mindspore/ops/operations/_quant_ops.py @@ -31,10 +31,12 @@ __all__ = ["FakeQuantPerLayer", "BatchNormFoldGrad", "CorrectionMul", "CorrectionMulGrad", + "CorrectionMulGradReduce", "BatchNormFold2", "BatchNormFold2Grad", "BatchNormFoldD", "BatchNormFoldGradD", + "BNTrainingReduce", "BatchNormFold2_D", "BatchNormFold2GradD", "BatchNormFold2GradReduce", @@ -332,7 +334,7 @@ class BatchNormFold(PrimitiveWithInfer): Batch normalization folded. Args: - momentum (float): Momentum value should be [0, 1]. Default: 0.9. + momentum (float): Momentum value should be [0, 1]. Default: 0.1. epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in float32 else 1e-3. Default: 1e-5. is_training (bool): In training mode set True, else set False. Default: True. @@ -364,7 +366,7 @@ class BatchNormFold(PrimitiveWithInfer): channel_axis = 1 @prim_attr_register - def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0): + def __init__(self, momentum=0.1, epsilon=1e-5, is_training=True, freeze_bn=0): """init batch norm fold layer""" self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name) @@ -499,7 +501,7 @@ class CorrectionMulGrad(PrimitiveWithInfer): from mindspore.ops._op_impl._custom_op import correction_mul_grad self.channel_axis = channel_axis self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'], - outputs=['dx', 'd_gamma']) + outputs=['dx', 'mul_dx']) def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape): validator.check("dout shape", dout_shape, "x_shape x", x_shape, Rel.EQ, self.name) @@ -507,12 +509,45 @@ class CorrectionMulGrad(PrimitiveWithInfer): Rel.EQ, self.name) validator.check("running_std_shape[0]", running_std_shape[0], "dout channel size", dout_shape[self.channel_axis], Rel.EQ, self.name) + if context.get_context('device_target') == "Ascend": + return x_shape, x_shape return x_shape, gamma_shape def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type): args = {"dout": dout_type, "x": x_type, "gamma": gamma_type, "running_std": running_std_type} validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) - return x_type, x_type + if context.get_context('device_target') == "Ascend": + return x_type, x_type + return x_type, gamma_type + + +class CorrectionMulGradReduce(PrimitiveWithInfer): + r""" + Performs grad reduce of CorrectionMul operation. + + Examples: + >>> correction_mul_grad_rd = P.CorrectionMulGradReduce() + >>> dout = Tensor(np.array([1.5, -2.2, 0.7, -3, 1.6, 2.8]).reshape(2, 1, 1, 3), mindspore.float32) + >>> input_x = Tensor(np.random.randint(0, 256, (2, 1, 1, 3)), mindspore.float32) + >>> gamma = Tensor(np.array([0.2, -0.2, 2.5, -1.]).reshape(2, 1, 2), mindspore.float32) + >>> running_std = Tensor(np.array([1.2, 0.1, 0.7, 2.3]).reshape(2, 1, 2), mindspore.float32) + >>> result = correction_mul_grad_rd(dout, input_x, gamma, running_std) + """ + + @prim_attr_register + def __init__(self, channel_axis=0): + """init correction mul reduce layer""" + if context.get_context('device_target') == "Ascend": + from mindspore.ops._op_impl._custom_op import correction_mul_grad + self.channel_axis = channel_axis + self.init_prim_io_names(inputs=['mul_dx'], + outputs=['d_gamma']) + + def infer_shape(self, mul_dx_shape): + return [mul_dx_shape[self.channel_axis]] + + def infer_dtype(self, mul_dx_type): + return mul_dx_type class BatchNormFold2(PrimitiveWithInfer): @@ -696,6 +731,32 @@ class BatchNormFoldGradD(PrimitiveWithInfer): return x_type +class BNTrainingReduce(PrimitiveWithInfer): + """ + reduce sum at axis [0, 2, 3]. + + Inputs: + - **x** (Tensor) - Tensor of shape :math:`(N, C)`. + + Outputs: + - **x_sum** (Tensor) - Tensor has the same shape as x. + - **x_square_sum** (Tensor) - Tensor has the same shape as x. + + """ + + @prim_attr_register + def __init__(self): + """init _BNTrainingReduce layer""" + self.init_prim_io_names(inputs=['x'], + outputs=['x_sum', 'x_square_sum']) + + def infer_shape(self, x_shape): + return [x_shape[1]], [x_shape[1]] + + def infer_dtype(self, x_type): + return x_type, x_type + + class BatchNormFold2_D(PrimitiveWithInfer): """ Scale the bias with a correction factor to the long term statistics