Unverified commit 4da9b87b, authored by zhouweiwei2014, committed by GitHub

[Zero-Dim] Make functools.reduce calls safer with an initial value, to support empty lists (#53182)

Parent f424162c
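Why the initializer matters: a zero-dim (0-D) tensor has an empty shape list, and functools.reduce raises a TypeError on an empty sequence unless an initial value is supplied. A minimal standalone sketch (not part of the diff) illustrating the behavior the change relies on:

    from functools import reduce

    shape = []  # the shape of a zero-dim (0-D) tensor is an empty list

    # Without an initializer, reduce() cannot handle an empty sequence:
    #     reduce(lambda x, y: x * y, shape)   # raises TypeError
    # With the initializer 1, it returns 1, the element count of a 0-D tensor:
    numel = reduce(lambda x, y: x * y, shape, 1)
    print(numel)  # 1

    # Non-empty shapes are unaffected; the initializer only multiplies by 1:
    print(reduce(lambda x, y: x * y, [2, 3, 4], 1))  # 24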
@@ -252,7 +252,7 @@ class CostEstimator:
     def _calculate_bytes(self, sizes, dtype):
         if sizes:
-            total_count = reduce(lambda x, y: x * y, sizes)
+            total_count = reduce(lambda x, y: x * y, sizes, 1)
         else:
             total_count = 0

@@ -96,7 +96,7 @@ class TensorCost:
         shape = self.shape
         dtype = self.dtype
-        total_count = reduce(lambda x, y: x * y, shape)
+        total_count = reduce(lambda x, y: x * y, shape, 1)
         if dtype == paddle.float32 or dtype == paddle.int32:
             dtype_factor = 4

@@ -336,7 +336,7 @@ class PlanSpace:
         ops = program.global_block().ops
         vars = program.global_block().vars
-        processes = reduce(lambda x, y: x * y, process_mesh_topology)
+        processes = reduce(lambda x, y: x * y, process_mesh_topology, 1)
         global_group = list(range(processes))
         global_process_mesh = None
         pipeline_process_meshes = None

@@ -1120,7 +1120,7 @@ class Resharder:
         """Compute the index of process_shape corresponding to the process."""
         relative_process = process_group.index(process)
         process_index = []
-        product = reduce(lambda x, y: x * y, process_shape)
+        product = reduce(lambda x, y: x * y, process_shape, 1)
         for i in range(len(process_shape)):
             idx = relative_process // (product // process_shape[i])

@@ -2120,7 +2120,7 @@ class RuleBasedTuner:
         has_used_devices = 0
         self.device_meshes_list.append([])
         for device_mesh in device_meshes:
-            devices = reduce(lambda x, y: x * y, device_mesh)
+            devices = reduce(lambda x, y: x * y, device_mesh, 1)
             processes = list(
                 range(has_used_devices, has_used_devices + devices)
             )

@@ -1684,7 +1684,7 @@ def get_standalone_cost_data(distributed_programs):
                 ].split(",")
                 shape = [int(x.strip()) for x in shape]
                 dtype_factor = 1
-                total_static_input_size += reduce(lambda x, y: x * y, shape)
+                total_static_input_size += reduce(lambda x, y: x * y, shape, 1)
                 if op.type == "c_embedding":
                     arg_name_lower = (
                         "w" if arg_name_lower == "weight" else "ids"

@@ -1838,7 +1838,7 @@ def get_var_numel(var):
     """
     assert isinstance(var, Variable)
     assert -1 not in var.shape
-    return reduce(lambda x, y: x * y, var.shape)
+    return reduce(lambda x, y: x * y, var.shape, 1)

 def get_lr(optimizer):
@@ -62,7 +62,7 @@ class CommunicateTopology:
         self.coordinate = collections.namedtuple(
             'Coordinate', self._parallel_names
         )
-        self._world_size = reduce(lambda x, y: x * y, self._dims)
+        self._world_size = reduce(lambda x, y: x * y, self._dims, 1)
         ranges = [range(d) for d in self._dims]
         all_coordinate = [self.coordinate(*x) for x in product(*ranges)]

@@ -113,7 +113,7 @@ class DGCMomentumOptimizer(Optimizer):
         return regular_type, regular_coeff

     def _is_use_dgc(self, param_var, grad_var):
-        var_numel = abs(reduce(lambda x, y: x * y, param_var.shape))
+        var_numel = abs(reduce(lambda x, y: x * y, param_var.shape, 1))
         if (
             var_numel < 16384
             or param_var.type == core.VarDesc.VarType.SELECTED_ROWS

@@ -111,7 +111,7 @@ class DygraphShardingOptimizer:
         for param in self._parameter_list:
             rank = sizes.index(min(sizes))
             mapping[rank].append(param)
-            numel = reduce(lambda x, y: x * y, param.shape)
+            numel = reduce(lambda x, y: x * y, param.shape, 1)
             assert (
                 numel > 0
             ), "param [{}] should larger than 0, but it is [{}]".format(

@@ -898,7 +898,7 @@ def get_var_size(param):
     """
     assert -1 not in param.shape
     return (
-        reduce(lambda x, y: x * y, param.shape)
+        reduce(lambda x, y: x * y, param.shape, 1)
         * DtypeToSize[param.dtype]
         / 1024.0
         / 1024.0

@@ -75,8 +75,8 @@ def _get_dpmp_topology(origin_topology, sharding_group):
         sharding_axis = 0
         dp_sharding_topology = dp_sharding_topology[1:]

-    product_dp_sharding = reduce(lambda x, y: x * y, dp_sharding_topology)
-    product_topology = reduce(lambda x, y: x * y, origin_topology)
+    product_dp_sharding = reduce(lambda x, y: x * y, dp_sharding_topology, 1)
+    product_topology = reduce(lambda x, y: x * y, origin_topology, 1)
     if product_topology == product_dp_sharding:
         dpmp_topology = dp_sharding_topology
@@ -274,7 +274,7 @@ class ClipHelper:
         for param in params:
             rank = sizes.index(min(sizes))
             mapping[rank].append(param.name)
-            numel = reduce(lambda x, y: x * y, param.shape)
+            numel = reduce(lambda x, y: x * y, param.shape, 1)
             assert (
                 numel > 0
             ), "param [{}] should larger than 0, but it is [{}]".format(

@@ -1661,7 +1661,7 @@ def partition_by_greedy_even(params, group_size):
     for param in params:
         rank = sizes.index(min(sizes))
         mapping[rank].append(param)
-        numel = reduce(lambda x, y: x * y, param.shape)
+        numel = reduce(lambda x, y: x * y, param.shape, 1)
         assert (
             numel > 0
         ), "param [{}] should larger than 0, but it is [{}]".format(

@@ -386,7 +386,7 @@ def get_dense_send_context(
             grad = merged[1]
             origin_varnames.append(grad.merged_var.name)
             var = program.global_block().vars[grad.merged_var.name]
-            var_numel += reduce(lambda x, y: x * y, var.shape)
+            var_numel += reduce(lambda x, y: x * y, var.shape, 1)
         grad_name = "Dense@GRAD_" + str(idx)
         aggregate = True
         # print("public get_dense_send_context dense_table:", grad_name,

@@ -422,7 +422,7 @@ def get_dense_send_context(
             grad = merged[1]
             origin_varnames.append(grad.merged_var.name)
             var = program.global_block().vars[grad.merged_var.name]
-            var_numel += reduce(lambda x, y: x * y, var.shape)
+            var_numel += reduce(lambda x, y: x * y, var.shape, 1)
         grad_name = "DataNorm@GRAD_" + str(idx)
         aggregate = True
         # print("public get_dense_send_context data_norm table:", grad_name,

@@ -452,7 +452,7 @@ def get_dense_send_context(
            grad = merged[1]
            origin_varname = grad.merged_var.name
            var = program.global_block().vars[origin_varname]
-           var_numel = reduce(lambda x, y: x * y, var.shape)
+           var_numel = reduce(lambda x, y: x * y, var.shape, 1)
            grad_name = origin_varname
            aggregate = True
            from paddle.fluid.core import CommContext

@@ -503,7 +503,7 @@ def get_geo_trainer_send_context(attrs):
                 True if param_name in distibuted_varnames else False
             )
             var = program.global_block().vars[grad.merged_var.name]
-            var_numel = reduce(lambda x, y: x * y, var.shape[1:])
+            var_numel = reduce(lambda x, y: x * y, var.shape[1:], 1)
             from paddle.fluid.core import CommContext
             print(

@@ -1167,7 +1167,7 @@ def get_communicate_var_info(
     for name in entrance_var_list:
         var = program.global_block().vars[name]
         shape = var.shape
-        recv_var_dim = -1 * reduce(lambda x, y: x * y, shape)
+        recv_var_dim = -1 * reduce(lambda x, y: x * y, shape, 1)
         input_var_reshape_dim.append(recv_var_dim)
         input_var_reshape_name.append(f"{name}.input_reshape@Heter")
@@ -1448,7 +1448,7 @@ dtype_to_size = {
 def get_var_mem_size(var):
-    m_size = reduce(lambda x, y: x * y, var.shape)
+    m_size = reduce(lambda x, y: x * y, var.shape, 1)
     m_size *= dtype_to_size[var.dtype]
     return m_size

@@ -117,7 +117,7 @@ def slice_variable(var_list, slice_count, min_block_size):
     blocks = []
     for var in var_list:
         split_count = slice_count
-        var_numel = reduce(lambda x, y: x * y, var.shape)
+        var_numel = reduce(lambda x, y: x * y, var.shape, 1)
         max_pserver_count = int(math.floor(var_numel / float(min_block_size)))
         if max_pserver_count == 0:
             max_pserver_count = 1

@@ -127,7 +127,7 @@ def slice_variable(var_list, slice_count, min_block_size):
         if len(var.shape) >= 2:
             # align by dim1(width)
-            dim1 = reduce(lambda x, y: x * y, var.shape[1:])
+            dim1 = reduce(lambda x, y: x * y, var.shape[1:], 1)
             remains = block_size % dim1
             if remains != 0:
                 block_size += dim1 - remains

@@ -2286,7 +2286,9 @@ WIKI: https://github.com/PaddlePaddle/Fleet/blob/develop/markdown_doc/transpiler
         orig_shape = orig_var.shape
         orig_dim1_flatten = 1
         if len(orig_shape) >= 2:
-            orig_dim1_flatten = reduce(lambda x, y: x * y, orig_shape[1:])
+            orig_dim1_flatten = reduce(
+                lambda x, y: x * y, orig_shape[1:], 1
+            )
         for i, block in enumerate(split):
             size = block[1]

@@ -5968,7 +5968,7 @@ class PipelineOptimizer:
         }
         assert -1 not in var.shape
         return (
-            reduce(lambda x, y: x * y, var.shape)
+            reduce(lambda x, y: x * y, var.shape, 1)
             * dtype_to_size[var.dtype]
             / 1024.0
             / 1024.0
@@ -46,7 +46,7 @@ class TestDygraphWeightNorm(unittest.TestCase):
     def norm_except_dim(self, w, dim=None):
         shape = w.shape
         ndims = len(shape)
-        shape_numel = reduce(lambda x, y: x * y, shape)
+        shape_numel = reduce(lambda x, y: x * y, shape, 1)
         if dim == -1:
             return np.linalg.norm(w, axis=None, keepdims=True).flatten()
         elif dim == 0:

@@ -68,7 +68,7 @@ class TestDygraphWeightNorm(unittest.TestCase):
     def weight_normalize(self, w, dim=None):
         shape = w.shape
         ndims = len(shape)
-        shape_numel = reduce(lambda x, y: x * y, shape)
+        shape_numel = reduce(lambda x, y: x * y, shape, 1)
         v = w
         g = self.norm_except_dim(w, dim)
         g_mul = g

@@ -1427,7 +1427,7 @@ class TestGradientTruncated(unittest.TestCase):
         paddle.enable_static()
         to_string = lambda x, i: x + '_' + str(i)
-        numel = lambda input_shape: reduce(lambda x, y: x * y, input_shape)
+        numel = lambda input_shape: reduce(lambda x, y: x * y, input_shape, 1)

         def op1(x):
             value = paddle.tensor.fill_constant([1], "float32", 1)

@@ -612,7 +612,7 @@ class TestListIndex(unittest.TestCase):
         np.random.seed(2022)

     def numel(self, shape):
-        return reduce(lambda x, y: x * y, shape)
+        return reduce(lambda x, y: x * y, shape, 1)

     def test_static_graph_list_index(self):
         paddle.enable_static()

@@ -117,7 +117,7 @@ class SliceInfo:
         return s

     def numel(self, shape):
-        return reduce(lambda x, y: x * y, shape)
+        return reduce(lambda x, y: x * y, shape, 1)

     def get_offset_stride(self, tensor_shape):
         for index in self.indexes:
@@ -652,7 +652,7 @@ class CompileTimeStrategy:
             var = self.origin_main_program.global_block().vars[
                 grad.merged_var.name
             ]
-            var_numel = reduce(lambda x, y: x * y, var.shape[1:])
+            var_numel = reduce(lambda x, y: x * y, var.shape[1:], 1)
             sparse_ctx = core.CommContext(
                 grad_name,

@@ -705,7 +705,7 @@ class CompileTimeStrategy:
                 var = self.origin_main_program.global_block().vars[
                     grad.merged_var.name
                 ]
-                var_numel += reduce(lambda x, y: x * y, var.shape)
+                var_numel += reduce(lambda x, y: x * y, var.shape, 1)
             grad_name = "Dense@Grad"
             trainer_id = self.get_role_id()
             aggregate = True

@@ -734,7 +734,7 @@ class CompileTimeStrategy:
             var = self.origin_main_program.global_block().vars[
                 origin_varname
             ]
-            var_numel = reduce(lambda x, y: x * y, var.shape)
+            var_numel = reduce(lambda x, y: x * y, var.shape, 1)
             grad_name = origin_varname
             aggregate = True
             dense_ctx = core.CommContext(

@@ -1058,7 +1058,7 @@ class CompileTimeStrategy:
         blocks = []
         for var in var_list:
             if not uniform:
-                var_numel = reduce(lambda x, y: x * y, var.shape)
+                var_numel = reduce(lambda x, y: x * y, var.shape, 1)
                 split_count = 1

@@ -1077,7 +1077,7 @@ class CompileTimeStrategy:
                 if len(var.shape) >= 2:
                     # align by dim1(width)
-                    dim1 = reduce(lambda x, y: x * y, var.shape[1:])
+                    dim1 = reduce(lambda x, y: x * y, var.shape[1:], 1)
                     remains = block_size % dim1
                     if remains != 0:
                         block_size += dim1 - remains

@@ -1102,7 +1102,7 @@ class CompileTimeStrategy:
                 for i in range(remainder):
                     dim0s[i] = dim0s[i] + 1
-                dim1 = reduce(lambda x, y: x * y, var.shape[1:])
+                dim1 = reduce(lambda x, y: x * y, var.shape[1:], 1)
                 for block_id in range(len(dim0s)):
                     numel = dim0s[block_id] * dim1

@@ -1484,7 +1484,7 @@ def get_communicate_var_info(
         # raise ValueError(
         #     "Variable {} not support heter training. its shape is {}".
         #     format(name, shape))
-        recv_var_dim = -1 * reduce(lambda x, y: x * y, shape)
+        recv_var_dim = -1 * reduce(lambda x, y: x * y, shape, 1)
         input_var_reshape_dim.append(recv_var_dim)
         input_var_reshape_name.append(f"{name}.input_reshape@Heter")

@@ -1497,7 +1497,7 @@ def get_communicate_var_info(
         # # raise ValueError(
         # #     "Variable {} not support heter training. its shape is {}".
         # #     format(var_name, shape))
-        # send_reshape_dim = -1 * reduce(lambda x, y: x * y, shape)
+        # send_reshape_dim = -1 * reduce(lambda x, y: x * y, shape, 1)
         # output_var_reshape_dim.append(send_reshape_dim)
         # output_var_reshape_name.append("{}.output_reshape@Heter".format(
         #     var_name))
@@ -65,7 +65,7 @@ class VarStruct:
         self.lod_level = lod_level
         self.persistable = persistable
         self.m_size = 1
-        self.m_size = reduce(lambda x, y: x * y, shape)
+        self.m_size = reduce(lambda x, y: x * y, shape, 1)
         self.m_size *= dtype_to_size[dtype]

     def __str__(self):

@@ -546,7 +546,7 @@ def local_response_norm(
         from functools import reduce

-        sum_sizes = reduce(lambda x, y: x * y, sizes[1:])
+        sum_sizes = reduce(lambda x, y: x * y, sizes[1:], 1)

         div = paddle.unsqueeze(paddle.multiply(x, x), axis=1)
         if not channel_last:

@@ -155,7 +155,7 @@ def vector_to_parameters(vec, parameters, name=None):
     for param in parameters:
         shape = param.shape
         origin_shapes.append(shape)
-        numel = reduce(lambda x, y: x * y, shape)
+        numel = reduce(lambda x, y: x * y, shape, 1)
         sections.append(numel)

     if len(sections) == 1:

@@ -3601,7 +3601,7 @@ def layer_norm(
     # create intput and parameters
     inputs = {'X': input}
     input_shape = input.shape
-    param_shape = [reduce(lambda x, y: x * y, input_shape[begin_norm_axis:])]
+    param_shape = [reduce(lambda x, y: x * y, input_shape[begin_norm_axis:], 1)]
     if scale:
         assert (
             param_attr is not False
@@ -46,13 +46,13 @@ class TestOneDNNReshapeTransposeMatmulFusePass(PassAutoScanTest):
         def generate_input2(attrs):
             shape_x = [attrs[3]['batch_size'], attrs[3]['channel'], self.num]
-            input_volume = reduce(lambda x, y: x * y, shape_x)
+            input_volume = reduce(lambda x, y: x * y, shape_x, 1)
             matmul_shape = list(attrs[0]['shape'])
             if 0 in matmul_shape:
                 for i in range(len(matmul_shape)):
                     if matmul_shape[i] == 0:
                         matmul_shape[i] = shape_x[i]
-            shape_volume = reduce(lambda x, y: x * y, matmul_shape)
+            shape_volume = reduce(lambda x, y: x * y, matmul_shape, 1)
             if -1 in matmul_shape:
                 for i in range(len(matmul_shape)):

@@ -1252,7 +1252,9 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
         paddle.enable_static()
         to_string = lambda x, i: x + '_' + str(i)
-        numel = lambda input_shape: reduce(lambda x, y: x * y, input_shape)
+        numel = lambda input_shape: reduce(
+            lambda x, y: x * y, input_shape, 1
+        )

         def op1(x):
             value = paddle.tensor.fill_constant([1], "float32", 1)