# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file contains composite rules of nonbasic operations. There are some notes:
# 1. When defining the composite rule of an op, you can only use primitive ops defined in primitives.py.
# 2. The name and args of the target op must correspond to the standard description of the op in
#    ops.yaml or legacy_ops.yaml.

import functools
import operator

import paddle.framework.dtype as dtypes
from paddle.fluid import core

from .primitives import *  # noqa: F403
from .primreg import REGISTER_COMPOSITE, lookup_composite


def _composite(op, *args):
    _lowerrule = lookup_composite(op.type)
    return _lowerrule(op, *args)


@REGISTER_COMPOSITE('softmax')
def softmax_composite(x, axis):
    """define composite rule of op softmax"""
    if not x.shape:
        # do not return 1, to ensure gradients
        res = exp(x - x)
        return res
    max_temp = max(x, axis, keepdim=True)
    max_temp.stop_gradient = True
    molecular = exp(x - max_temp)
    denominator = sum(molecular, axis=axis, keepdim=True)
    res = divide(molecular, denominator)
    return res


@REGISTER_COMPOSITE('batch_norm')
def composite_batchnorm(
    x,
    run_mean,
    run_var,
    scale,
    bias,
    is_test,
    momentum,
    epsilon,
    data_layout,
    use_global_stats,
    trainable_statistics,
):
    """define composite rule of op batch_norm"""
    is_amp = False
    from paddle.fluid.data_feeder import convert_dtype

    if convert_dtype(x.dtype) == "float16":
        print("Running batch_norm in amp")
        is_amp = True
        x = cast(x, "float32")

    feature_axis = (
        1 if data_layout in ('NC', 'NCL', 'NCHW', 'NCHWD') else len(x.shape) - 1
    )
    if use_global_stats is None:
        use_global_stats = is_test
        trainable_statistics = False
    else:
        trainable_statistics = not use_global_stats

    use_run_stat = (is_test and (not trainable_statistics)) or use_global_stats
    reduce_axes = tuple(i for i in range(len(x.shape)) if i != feature_axis)
    stats_shape = tuple(
        1 if i in reduce_axes else s for i, s in enumerate(x.shape)
    )

    batch_mean = zeros(run_mean.shape, run_mean.dtype)
    batch_var = zeros(run_var.shape, run_var.dtype)
    if not use_run_stat:
        batch_mean = mean(x, reduce_axes, keepdim=True)
        temp = mean(x * x, reduce_axes, keepdim=True)
        batch_var = temp - batch_mean * batch_mean

        x_hat = (x - reshape(batch_mean, stats_shape)) / sqrt(
            reshape(batch_var, stats_shape) + epsilon
        )

        run_mean = momentum * run_mean + (1 - momentum) * reshape(
            batch_mean, run_mean.shape
        )
        run_var = momentum * run_var + (1 - momentum) * reshape(
            batch_var, run_var.shape
        )
    else:
        x_hat = (x - reshape(run_mean, stats_shape)) / sqrt(
            reshape(run_var, stats_shape) + epsilon
        )
    y = reshape(scale, stats_shape) * x_hat + reshape(bias, stats_shape)
    if is_amp:
        y = cast(y, "float16")

    # add op assign to detach the tensors, to avoid unsafe changes outside the rule.
    batch_mean_ = assign(reshape(batch_mean, run_mean.shape))
    batch_var_ = assign(reshape(batch_var, run_var.shape))
    run_mean_ = assign(run_mean)
    run_var_ = assign(run_var)

    # reserve_space is not needed in the composite rule, but None is still returned
    # to keep the outputs the same as the phi op definition.
    reserve_space = None

    return y, run_mean_, run_var_, batch_mean_, batch_var_, reserve_space
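

# Illustrative note (the shapes below are hypothetical, not part of the rule):
# for an NCHW input with x.shape == (8, 16, 32, 32), feature_axis == 1, so
# reduce_axes == (0, 2, 3) and stats_shape == (1, 16, 1, 1); batch_mean and
# batch_var are then per-channel statistics that broadcast back over N, H and W.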


@REGISTER_COMPOSITE('layer_norm')
def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis):
    """
    define composite rule of op layer_norm
    out = (x - mean(x)) / sqrt(var + epsilon)
    var = mean((x - mean(x))^2)
    """
    axis = tuple(range(begin_norm_axis, len(x.shape)))
    mean_ = mean(x, axis=axis, keepdim=True)
    difference = x - mean_
    var_tmp1 = difference * difference
    variance = mean(var_tmp1, axis=axis, keepdim=True)
    var_tmp3 = variance + epsilon
    sqrt_var = sqrt(var_tmp3)
    out = difference / sqrt_var

    if scale is not None:
        scale = reshape(scale, x.shape[begin_norm_axis:])
        out = out * scale
    if bias is not None:
        bias = reshape(bias, x.shape[begin_norm_axis:])
        out = out + bias

    mean_ = reshape(mean_, [-1])
    variance = reshape(variance, [-1])
    return out, mean_, variance


@REGISTER_COMPOSITE('gelu')
def gelu_composite(x, approximate):
    """define composite rule of op gelu"""
    M_SQRT1_2 = (
        0.70710678118654752440  # /* 1/sqrt(2) */ copied from gelu-kernel.cc
    )
    M_2_SQRTPI = 1.12837916709551257390  # /* 2/sqrt(pi) */
    one = ones(x.shape, x.dtype)
    half = full(x.shape, 0.5, x.dtype)
    if approximate:
        # gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3})))
        kAlpha = full(x.shape, M_2_SQRTPI * M_SQRT1_2, x.dtype)
        GELU_CONSTANT = full(x.shape, 0.044715, x.dtype)
        tanh_out = tanh(kAlpha * (x + GELU_CONSTANT * x * x * x))
        out = x * half * (one + tanh_out)
        return out
    else:
        # gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
        cdf = half * (one + erf(x * full(x.shape, M_SQRT1_2, x.dtype)))
        out = x * cdf
        return out


@REGISTER_COMPOSITE('reduce_mean')
def mean_composite(x, axis, keepdim):
    """define composite rule of op mean"""
    axes = axis or list(range(0, len(x.shape)))
    axes = [axes] if isinstance(axes, int) else axes
    sum_x = sum(x, axis=axes, keepdim=keepdim)
    value_to_fill = functools.reduce(
        operator.mul, [x.shape[axis] for axis in axes]
    )
    norm = fill_constant(
        shape=sum_x.shape,
        value=value_to_fill,
        dtype=sum_x.dtype,
    )
    return divide(sum_x, norm)


@REGISTER_COMPOSITE('stack')
def stack_composite(x, axis):
    """
    define composite rule of op stack
    unsqueeze each input along the stack axis (using reshape), and then concat
    """
    x_shape = x[0].shape
    if axis < 0:
        axis += len(x_shape) + 1
    out_shape = x_shape[:axis] + (1,) + x_shape[axis:]
    out = concat([reshape(item, out_shape) for item in x], axis)
    return out


@REGISTER_COMPOSITE('flatten_contiguous_range')
def flatten_contiguous_range_composite(x, start_axis, stop_axis):
    """
    define composite rule of op flatten, flatten_contiguous_range -> flatten.
    CINN doesn't need xshape for the backward pass, so None is returned instead of xshape.
    shape_out is the parameter of reshape, computed from start_axis and stop_axis.
    out = reshape(x, shape=shape_out), xshape
    """
    shape_in = x.shape
    start_dim = start_axis if len(shape_in) != 0 else 0
    end_dim = stop_axis if len(shape_in) != 0 else 0
    assert start_dim <= end_dim
    if len(shape_in) == 0 or start_dim == end_dim:
        return reshape(x, shape=shape_in), None
    slice_numel = 1
    for i in range(start_dim, end_dim + 1):
        slice_numel *= shape_in[i]
    shape_out = []
    for i in range(start_dim):
        shape_out.append(shape_in[i])
    shape_out.append(slice_numel)
    for i in range(end_dim + 1, len(shape_in)):
        shape_out.append(shape_in[i])
    return reshape(x, shape=shape_out), None
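

# Illustrative note (hypothetical shapes, not executed): with shape_in == [2, 3, 4, 5],
# start_axis == 1 and stop_axis == 2, slice_numel == 3 * 4 == 12, so the rule
# returns reshape(x, shape=[2, 12, 5]) together with None in place of xshape.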


@REGISTER_COMPOSITE('dropout')
def dropout_composite(x, seed_tensor, p, is_test, mode, seed, fix_seed):
    """define composite rule of op dropout.
    upscale_in_train:
        train: out = input * mask / (1.0 - p)
        inference: out = input
    downscale_in_infer:
        train: out = input * mask
        inference: out = input * (1.0 - p)
    """
    fix_seed = True if fix_seed is None else fix_seed
    seed = seed if fix_seed else 0
    upscale_in_train = mode == "upscale_in_train"
    mask = bernoulli(shape=x.shape, dtype=x.dtype, p=p, seed=seed)

    if upscale_in_train:
        if not is_test:
            # Handle p=1.0 separately to avoid a divide-by-zero error in x * mask / (1.0 - p).
            if p == 1.0:
                return 0.0 * x, zeros(x.shape, core.VarDesc.VarType.UINT8)
            else:
                return x * mask / (1.0 - p), cast(
                    mask, core.VarDesc.VarType.UINT8
                )
        else:
            return assign(x), cast(mask, core.VarDesc.VarType.UINT8)
    else:
        if not is_test:
            return x * mask, cast(mask, core.VarDesc.VarType.UINT8)
        else:
            return x * (1.0 - p), cast(mask, core.VarDesc.VarType.UINT8)


def bernoulli(shape, dtype, p, seed=0):
    return cast(
        greater_equal(
            uniform(shape, dtype, min=0.0, max=1.0, seed=seed),
            fill_constant(shape, dtype, p),
        ),
        dtype,
    )


@REGISTER_COMPOSITE('hard_swish')
def hard_swish_composite(x):
    """define composite rule of op hard_swish.
    offset=3, threshold=6, scale=6
    out = minimum(maximum(x + offset, 0), threshold) * x / scale
    """
    offset = 3.0
    threshold = 6.0
    scale = 6.0
    res = (
        minimum(
            maximum(
                x + full(x.shape, offset, dtype=x.dtype),
                full(x.shape, 0.0, dtype=x.dtype),
            ),
            full(x.shape, threshold, dtype=x.dtype),
        )
        * x
        / full(x.shape, scale, dtype=x.dtype)
    )
    return res


@REGISTER_COMPOSITE('sigmoid')
def sigmoid_composite(x):
    """
    define composite rule of op sigmoid
    res = 1 / (1 + exp(-x))
    """
    sum_temp = 1 + exp(-x)
    res = 1 / sum_temp
    return res


@REGISTER_COMPOSITE('silu')
def silu_composite(x):
    """
    define composite rule of op silu
    res = x / (1 + exp(-x))
    """
    sum_temp = 1 + exp(-x)
    res = x / sum_temp
    return res


@REGISTER_COMPOSITE('fill_any_like')
def fill_any_like(x, fill_value, dtype, place=None):
    """define composite rule of op full_like.
    op name: full_like; op type name: fill_any_like.
    The arg place is not used; it is kept here to stay consistent with the python api.
    """
    dtype = dtypes.dtype(dtype)
    val = full(x.shape, fill_value, dtype)
    return val


@REGISTER_COMPOSITE('relu')
def relu_composite(x):
    """define composite rule of op relu."""
    # relu(x) = max(x, 0)
    return maximum(x, zeros_like(x))
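

# Illustrative sketch of how these rules are consumed (the op value is hypothetical):
# for a program op whose op.type is 'relu', lookup_composite('relu') resolves the
# registered relu_composite rule, which rewrites the op into maximum(x, zeros_like(x)),
# built only from the primitive ops exported by primitives.py.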