提交 86ba9362 编写于 作者: W wandongdong

split correction_mul op

上级 445122f5
......@@ -18,6 +18,7 @@
from .. import operations as P
from .grad_base import bprop_getters
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ... import context
@bprop_getters.register(P.FakeQuantPerLayer)
......@@ -64,12 +65,21 @@ def get_bprop_batchnorm_fold(self):
@bprop_getters.register(P.CorrectionMul)
def get_bprop_correction_mul(self):
"""Generate bprop for CorrectionMul for Ascend and GPU"""
grad = P.CorrectionMulGrad(self.channel_axis)
grad_dx = P.CorrectionMulGrad(self.channel_axis)
grad_d_batch_std = P.CorrectionMulGradReduce(self.channel_axis)
def bprop(x, batch_std, running_std, out, dout):
dx, d_batch_std = grad(dout, x, batch_std, running_std)
dx, d_batch_std = grad_dx(dout, x, batch_std, running_std)
return dx, d_batch_std, zeros_like(running_std)
def bprop_npu(x, batch_std, running_std, out, dout):
dx, mul_dx = grad_dx(dout, x, batch_std, running_std)
d_batch_std = grad_d_batch_std(mul_dx)
return dx, d_batch_std, zeros_like(running_std)
if context.get_context('device_target') == "Ascend":
return bprop_npu
return bprop
......
......@@ -37,7 +37,7 @@ correction_mul_grad_op_info = TBERegOp("CorrectionMulGrad") \
.input(2, "batch_std", None, "required", None) \
.input(3, "running_std", None, "required", None) \
.output(0, "dx", True, "required", "all") \
.output(1, "d_batch_std", True, "required", "all") \
.output(1, "mul_dx", True, "required", "all") \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
......@@ -56,21 +56,14 @@ def correction_mul_grad_compute(dout, x, batch_std, running_std, channel, data_f
factor = te.lang.cce.vdiv(batch_std, running_std)
factor_b = te.lang.cce.broadcast(factor, shape_x)
dx = te.lang.cce.vmul(dout, factor_b)
mul_data = te.lang.cce.vmul(dout, x)
if channel == 0:
if data_format == "NCHW":
axis = [1, 2, 3]
else:
axis = [1, 2, 3, 4]
else:
axis = [2, 3]
red_data = te.lang.cce.sum(mul_data, axis, keepdims=True)
d_batch_std = te.lang.cce.vdiv(red_data, running_std)
return [dx, d_batch_std]
mul_dx = te.lang.cce.vmul(dout, x)
running_std_b = te.lang.cce.broadcast(running_std, shape_x)
mul_dx = te.lang.cce.vdiv(mul_dx, running_std_b)
return [dx, mul_dx]
@util.check_input_type(dict, dict, dict, dict, dict, dict, int, str)
def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channel, kernel_name="correction_mul_grad"):
def correction_mul_grad(dout, x, batch_std, running_std, dx, mul_dx, channel, kernel_name="correction_mul_grad"):
"""CorrectionMulGrad op"""
shape_dout = dout.get("shape")
shape_x = dout.get("shape")
......@@ -93,7 +86,7 @@ def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channe
util.compare_tensor_dict_key(dout, x, "shape")
util.compare_tensor_dict_key(dx, x, "shape")
util.compare_tensor_dict_key(batch_std, running_std, "shape")
util.compare_tensor_dict_key(batch_std, d_batch_std, "shape")
util.compare_tensor_dict_key(dx, mul_dx, "shape")
util.check_kernel_name(kernel_name)
util.check_shape_rule(shape_x)
......@@ -120,7 +113,84 @@ def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channe
with tvm.target.cce():
sch = generic.auto_schedule(res_list)
tensor_list = [dout_t, x_t, batch_std_t, running_std_t] + list(res_list)
tensor_list = [dout_t, x_t, batch_std_t, running_std_t] + res_list
config = {"print_ir": False,
"name": kernel_name,
"tensor_list": tensor_list}
te.lang.cce.cce_build_code(sch, config)
correction_mul_grad_reduce_op_info = TBERegOp("CorrectionMulGradReduce") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("correction_mul_grad_reduce.so") \
.compute_cost(10) \
.kernel_name("correction_mul_grad_reduce") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.attr("channel_axis", "optional", "int", "all") \
.input(0, "dout", None, "required", None) \
.output(0, "d_batch_std", True, "required", "all") \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
@op_info_register(correction_mul_grad_reduce_op_info)
def _correction_mul_grad_reduce_tbe():
"""CorrectionMulGradReduce TBE register"""
return
@fusion_manager.register("correction_mul_grad_reduce")
def correction_mul_grad_reduce_compute(mul_dx, channel, data_format, kernel_name="correction_mul"):
"""CorrectionMulGradReduce compute"""
if channel == 0:
if data_format == "NCHW":
axis = [1, 2, 3]
else:
axis = [1, 2, 3, 4]
else:
axis = [2, 3]
d_batch_std = te.lang.cce.sum(mul_dx, axis, keepdims=True)
return d_batch_std
@util.check_input_type(dict, dict, int, str)
def correction_mul_grad_reduce(mul_dx, d_batch_std, channel, kernel_name="correction_mul_grad_reduce"):
"""CorrectionMulGradReduce op"""
shape_dout = mul_dx.get("shape")
shape_x = mul_dx.get("shape")
dtype_dout = mul_dx.get("dtype")
inp_dtype_dout = dtype_dout.lower()
util.check_dtype_rule(inp_dtype_dout, ("float16", "float32"))
util.check_kernel_name(kernel_name)
util.check_shape_rule(shape_x)
util.check_shape_size(shape_x, SHAPE_SIZE_LIMIT)
data_format = mul_dx.get("format")
ori_format = mul_dx.get("format")
if data_format.upper() not in ("NC1HWC0", "NCHW"):
raise RuntimeError("Un supported data format {}".format(data_format))
if data_format.upper() == "NCHW" and ori_format != "NCHW":
raise RuntimeError("data_format(NCHW) must same as ori_format")
shape_c = [1] * len(shape_x)
shape_c[channel] = d_batch_std.get("ori_shape")[0]
if data_format == "NC1HWC0" and channel == 1:
shape_c = d_batch_std.get("shape")
dout_t = tvm.placeholder(shape_dout, name="dout", dtype=inp_dtype_dout)
res = correction_mul_grad_reduce_compute(dout_t, channel, data_format, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res)
tensor_list = [dout_t, res]
config = {"print_ir": False,
"name": kernel_name,
"tensor_list": tensor_list}
......
......@@ -31,6 +31,7 @@ __all__ = ["FakeQuantPerLayer",
"BatchNormFoldGrad",
"CorrectionMul",
"CorrectionMulGrad",
"CorrectionMulGradReduce",
"BatchNormFold2",
"BatchNormFold2Grad",
"BatchNormFoldD",
......@@ -500,7 +501,7 @@ class CorrectionMulGrad(PrimitiveWithInfer):
from mindspore.ops._op_impl._custom_op import correction_mul_grad
self.channel_axis = channel_axis
self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'],
outputs=['dx', 'd_gamma'])
outputs=['dx', 'mul_dx'])
def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape):
validator.check("dout shape", dout_shape, "x_shape x", x_shape, Rel.EQ, self.name)
......@@ -508,12 +509,45 @@ class CorrectionMulGrad(PrimitiveWithInfer):
Rel.EQ, self.name)
validator.check("running_std_shape[0]", running_std_shape[0],
"dout channel size", dout_shape[self.channel_axis], Rel.EQ, self.name)
if context.get_context('device_target') == "Ascend":
return x_shape, x_shape
return x_shape, gamma_shape
def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type):
args = {"dout": dout_type, "x": x_type, "gamma": gamma_type, "running_std": running_std_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
return x_type, x_type
if context.get_context('device_target') == "Ascend":
return x_type, x_type
return x_type, gamma_type
class CorrectionMulGradReduce(PrimitiveWithInfer):
r"""
Performs grad reduce of CorrectionMul operation.
Examples:
>>> correction_mul_grad_rd = P.CorrectionMulGradReduce()
>>> dout = Tensor(np.array([1.5, -2.2, 0.7, -3, 1.6, 2.8]).reshape(2, 1, 1, 3), mindspore.float32)
>>> input_x = Tensor(np.random.randint(0, 256, (2, 1, 1, 3)), mindspore.float32)
>>> gamma = Tensor(np.array([0.2, -0.2, 2.5, -1.]).reshape(2, 1, 2), mindspore.float32)
>>> running_std = Tensor(np.array([1.2, 0.1, 0.7, 2.3]).reshape(2, 1, 2), mindspore.float32)
>>> result = correction_mul_grad_rd(dout, input_x, gamma, running_std)
"""
@prim_attr_register
def __init__(self, channel_axis=0):
"""init correction mul reduce layer"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import correction_mul_grad
self.channel_axis = channel_axis
self.init_prim_io_names(inputs=['mul_dx'],
outputs=['d_gamma'])
def infer_shape(self, mul_dx_shape):
return [mul_dx_shape[self.channel_axis]]
def infer_dtype(self, mul_dx_type):
return mul_dx_type
class BatchNormFold2(PrimitiveWithInfer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册