From 6f53d3b2361c5b669b34dbc88269997ca6944512 Mon Sep 17 00:00:00 2001 From: Difer <707065510@qq.com> Date: Mon, 31 Jul 2023 16:10:05 +0800 Subject: [PATCH] reaplce fill_constant_batch_size_like (#55522) * simple reaplce * for debug * fix bugs * fix some bugs * del fill_constant_batch_size_like --- .../auto_parallel/static/cost/comp_op_cost.py | 8 ++ python/paddle/distribution/bernoulli.py | 14 +- python/paddle/distribution/categorical.py | 4 +- python/paddle/distribution/distribution.py | 8 +- python/paddle/distribution/normal.py | 19 +-- python/paddle/distribution/uniform.py | 12 +- python/paddle/fluid/layers/__init__.py | 3 - .../fluid/layers/learning_rate_scheduler.py | 3 +- python/paddle/fluid/layers/tensor.py | 125 ------------------ python/paddle/fluid/optimizer.py | 1 - python/paddle/nn/layer/rnn.py | 44 ++++-- python/paddle/nn/layer/transformer.py | 17 +-- test/auto_parallel/test_base_cost.py | 7 +- test/auto_parallel/test_dist_op_cost.py | 125 +++++++++--------- test/auto_parallel/test_while_op_partition.py | 17 ++- test/legacy_test/auto_parallel_gpt_model.py | 17 +-- test/legacy_test/dist_fleet_simnet_bow.py | 12 +- .../test_auto_parallel_completion_gpt.py | 17 +-- .../test_auto_parallel_partitioner_gpt.py | 17 +-- test/legacy_test/test_dist_fleet_minimize.py | 12 +- test/legacy_test/test_dist_fleet_ps.py | 12 +- test/legacy_test/test_dist_fleet_ps11.py | 12 +- test/legacy_test/test_dist_fleet_ps12.py | 12 +- test/legacy_test/test_dist_fleet_ps13.py | 12 +- test/legacy_test/test_dist_fleet_ps2.py | 12 +- test/legacy_test/test_dist_fleet_ps3.py | 12 +- test/legacy_test/test_dist_fleet_ps4.py | 12 +- test/legacy_test/test_dist_fleet_ps5.py | 12 +- test/legacy_test/test_dist_fleet_ps6.py | 12 +- test/legacy_test/test_dist_fleet_spmt.py | 12 +- test/legacy_test/test_dist_transpiler.py | 14 +- .../test_fill_constant_batch_size_like.py | 124 ----------------- test/legacy_test/test_layers.py | 10 -- 33 files changed, 260 insertions(+), 490 deletions(-) delete mode 100644 python/paddle/fluid/layers/tensor.py delete mode 100644 test/legacy_test/test_fill_constant_batch_size_like.py diff --git a/python/paddle/distributed/auto_parallel/static/cost/comp_op_cost.py b/python/paddle/distributed/auto_parallel/static/cost/comp_op_cost.py index ea6d2ef571c..1039a7b2305 100644 --- a/python/paddle/distributed/auto_parallel/static/cost/comp_op_cost.py +++ b/python/paddle/distributed/auto_parallel/static/cost/comp_op_cost.py @@ -471,6 +471,14 @@ class ScaleOpCost(CompOpCost): super().__init__(op=op, op_desc=op_desc, cluster=cluster) +@register_op_cost +class ShapeOpCost(CompOpCost): + OP_TYPE = "shape" + + def __init__(self, op=None, op_desc=None, cluster=None): + super().__init__(op=op, op_desc=op_desc, cluster=cluster) + + @register_op_cost class SliceOpCost(CompOpCost): OP_TYPE = "slice" diff --git a/python/paddle/distribution/bernoulli.py b/python/paddle/distribution/bernoulli.py index 9ae721dad4e..1b365bbcd31 100644 --- a/python/paddle/distribution/bernoulli.py +++ b/python/paddle/distribution/bernoulli.py @@ -18,7 +18,7 @@ import numpy as np import paddle from paddle.distribution import exponential_family from paddle.fluid.data_feeder import check_type, convert_dtype -from paddle.fluid.layers import tensor +from paddle.fluid.framework import Variable from paddle.framework import in_dynamic_mode from paddle.nn.functional import ( binary_cross_entropy_with_logits, @@ -97,7 +97,7 @@ class Bernoulli(exponential_family.ExponentialFamily): check_type( probs, 'probs', - (float, tensor.Variable), + 
(float, Variable), self.name, ) @@ -180,7 +180,7 @@ class Bernoulli(exponential_family.ExponentialFamily): check_type( shape, 'shape', - (np.ndarray, tensor.Variable, list, tuple), + (np.ndarray, Variable, list, tuple), name, ) @@ -259,7 +259,7 @@ class Bernoulli(exponential_family.ExponentialFamily): check_type( shape, 'shape', - (np.ndarray, tensor.Variable, list, tuple), + (np.ndarray, Variable, list, tuple), name, ) check_type( @@ -318,7 +318,7 @@ class Bernoulli(exponential_family.ExponentialFamily): """ name = self.name + '_cdf' if not in_dynamic_mode(): - check_type(value, 'value', tensor.Variable, name) + check_type(value, 'value', Variable, name) value = self._check_values_dtype_in_probs(self.probs, value) probs, value = paddle.broadcast_tensors([self.probs, value]) @@ -356,7 +356,7 @@ class Bernoulli(exponential_family.ExponentialFamily): """ name = self.name + '_log_prob' if not in_dynamic_mode(): - check_type(value, 'value', tensor.Variable, name) + check_type(value, 'value', Variable, name) value = self._check_values_dtype_in_probs(self.probs, value) logits, value = paddle.broadcast_tensors([self.logits, value]) @@ -395,7 +395,7 @@ class Bernoulli(exponential_family.ExponentialFamily): """ name = self.name + '_prob' if not in_dynamic_mode(): - check_type(value, 'value', tensor.Variable, name) + check_type(value, 'value', Variable, name) return self.log_prob(value).exp(name=name) diff --git a/python/paddle/distribution/categorical.py b/python/paddle/distribution/categorical.py index 3eea67157f7..1af187c2cfe 100644 --- a/python/paddle/distribution/categorical.py +++ b/python/paddle/distribution/categorical.py @@ -17,7 +17,7 @@ import numpy as np import paddle from paddle.distribution import distribution from paddle.fluid.data_feeder import check_type, convert_dtype -from paddle.fluid.layers import tensor +from paddle.fluid.framework import Variable from paddle.framework import in_dynamic_mode from paddle.tensor import multinomial @@ -100,7 +100,7 @@ class Categorical(distribution.Distribution): check_type( logits, 'logits', - (np.ndarray, tensor.Variable, list, tuple), + (np.ndarray, Variable, list, tuple), 'Categorical', ) diff --git a/python/paddle/distribution/distribution.py b/python/paddle/distribution/distribution.py index 023bd53d7b4..68d468accee 100644 --- a/python/paddle/distribution/distribution.py +++ b/python/paddle/distribution/distribution.py @@ -26,7 +26,7 @@ import numpy as np import paddle from paddle import _C_ops from paddle.fluid.data_feeder import check_variable_and_dtype, convert_dtype -from paddle.fluid.layers import tensor +from paddle.fluid.framework import Variable from paddle.framework import in_dynamic_mode @@ -150,7 +150,7 @@ class Distribution: is_variable = False is_number = False for arg in args: - if isinstance(arg, tensor.Variable): + if isinstance(arg, Variable): is_variable = True else: is_number = True @@ -176,9 +176,7 @@ class Distribution: tmp = 0.0 for arg in args: - if not isinstance( - arg, (float, list, tuple, np.ndarray, tensor.Variable) - ): + if not isinstance(arg, (float, list, tuple, np.ndarray, Variable)): raise TypeError( "Type of input args must be float, list, tuple, numpy.ndarray or Tensor, but received type {}".format( type(arg) diff --git a/python/paddle/distribution/normal.py b/python/paddle/distribution/normal.py index 7ba987819a3..07b1b810d9b 100644 --- a/python/paddle/distribution/normal.py +++ b/python/paddle/distribution/normal.py @@ -20,7 +20,7 @@ import numpy as np import paddle from paddle.distribution import 
distribution from paddle.fluid.data_feeder import check_type, convert_dtype -from paddle.fluid.layers import tensor +from paddle.fluid.framework import Variable from paddle.framework import in_dynamic_mode from paddle.tensor import random @@ -91,13 +91,13 @@ class Normal(distribution.Distribution): check_type( loc, 'loc', - (int, float, np.ndarray, tensor.Variable, list, tuple), + (int, float, np.ndarray, Variable, list, tuple), 'Normal', ) check_type( scale, 'scale', - (int, float, np.ndarray, tensor.Variable, list, tuple), + (int, float, np.ndarray, Variable, list, tuple), 'Normal', ) @@ -174,9 +174,9 @@ class Normal(distribution.Distribution): name = self.name + '_sample' if -1 in batch_shape: output_shape = shape + batch_shape - zero_tmp = tensor.fill_constant_batch_size_like( - self.loc + self.scale, batch_shape + shape, self.dtype, 0.0 - ) + fill_shape = list(batch_shape + shape) + fill_shape[0] = paddle.shape(self.loc + self.scale)[0].item() + zero_tmp = paddle.full(fill_shape, 0.0, self.dtype) zero_tmp_reshape = paddle.reshape(zero_tmp, output_shape) zero_tmp_shape = paddle.shape(zero_tmp_reshape) @@ -234,9 +234,10 @@ class Normal(distribution.Distribution): name = self.name + '_entropy' batch_shape = list((self.loc + self.scale).shape) if -1 in batch_shape: - zero_tmp = tensor.fill_constant_batch_size_like( - self.loc + self.scale, batch_shape, self.dtype, 0.0 - ) + fill_shape = list(batch_shape) + fill_shape[0] = paddle.shape(self.loc + self.scale)[0].item() + fill_dtype = (self.loc + self.scale).dtype + zero_tmp = paddle.full(fill_shape, 0.0, fill_dtype) else: zero_tmp = paddle.full(batch_shape, 0.0, self.dtype) return paddle.add( diff --git a/python/paddle/distribution/uniform.py b/python/paddle/distribution/uniform.py index 5619258efcf..dbd27fd1472 100644 --- a/python/paddle/distribution/uniform.py +++ b/python/paddle/distribution/uniform.py @@ -18,7 +18,7 @@ import paddle from paddle import _C_ops from paddle.distribution import distribution from paddle.fluid.data_feeder import check_type, convert_dtype -from paddle.fluid.layers import tensor +from paddle.fluid.framework import Variable from paddle.framework import in_dynamic_mode from paddle.tensor import random @@ -105,13 +105,13 @@ class Uniform(distribution.Distribution): check_type( low, 'low', - (int, float, np.ndarray, tensor.Variable, list, tuple), + (int, float, np.ndarray, Variable, list, tuple), 'Uniform', ) check_type( high, 'high', - (int, float, np.ndarray, tensor.Variable, list, tuple), + (int, float, np.ndarray, Variable, list, tuple), 'Uniform', ) @@ -169,9 +169,9 @@ class Uniform(distribution.Distribution): batch_shape = list((self.low + self.high).shape) if -1 in batch_shape: output_shape = shape + batch_shape - zero_tmp = tensor.fill_constant_batch_size_like( - self.low + self.high, batch_shape + shape, self.dtype, 0.0 - ) + fill_shape = list(batch_shape + shape) + fill_shape[0] = paddle.shape(self.low + self.high)[0].item() + zero_tmp = paddle.full(fill_shape, 0.0, self.dtype) uniform_random_tmp = random.uniform_random_batch_size_like( zero_tmp, zero_tmp.shape, diff --git a/python/paddle/fluid/layers/__init__.py b/python/paddle/fluid/layers/__init__.py index b91d7de093c..c5eb01ff763 100644 --- a/python/paddle/fluid/layers/__init__.py +++ b/python/paddle/fluid/layers/__init__.py @@ -16,8 +16,6 @@ from . import nn from .nn import * from . import io from .io import * -from . import tensor -from .tensor import * from . 
import math_op_patch from .math_op_patch import * from .learning_rate_scheduler import * @@ -27,5 +25,4 @@ from ..layer_helper import LayerHelper __all__ = [] __all__ += nn.__all__ __all__ += io.__all__ -__all__ += tensor.__all__ __all__ += learning_rate_scheduler.__all__ diff --git a/python/paddle/fluid/layers/learning_rate_scheduler.py b/python/paddle/fluid/layers/learning_rate_scheduler.py index 050cc774ab7..59f25c63b74 100644 --- a/python/paddle/fluid/layers/learning_rate_scheduler.py +++ b/python/paddle/fluid/layers/learning_rate_scheduler.py @@ -25,7 +25,6 @@ import numbers import paddle from . import nn -from . import tensor from ..framework import ( default_main_program, Parameter, @@ -488,7 +487,7 @@ def cosine_decay(learning_rate, step_each_epoch, epochs): learning_rate = base_lr, step_each_epoch=10000, epochs=120) """ check_type( - learning_rate, 'learning_rate', (float, tensor.Variable), 'cosine_decay' + learning_rate, 'learning_rate', (float, Variable), 'cosine_decay' ) with default_main_program()._lr_schedule_guard(): diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py deleted file mode 100644 index 06cfbf1cecb..00000000000 --- a/python/paddle/fluid/layers/tensor.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unlessf required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import paddle -import numpy -import warnings - -from ..layer_helper import LayerHelper -from ..framework import ( - _current_expected_place, - convert_np_dtype_to_dtype_, - _create_tensor, - in_dygraph_mode, -) -from ..framework import Variable -from ..core import VarDesc -from .. import core -from .layer_function_generator import templatedoc -from ..data_feeder import ( - check_variable_and_dtype, - check_type, - check_dtype, - convert_dtype, -) -from paddle.utils import deprecated - -from paddle import _C_ops, _legacy_C_ops - -__all__ = [ - 'fill_constant_batch_size_like', -] - - -@deprecated(since='1.8.0', update_to="paddle.fluid.layers.fill_constant") -@templatedoc() -def fill_constant_batch_size_like( - input, - shape, - dtype, - value, - input_dim_idx=0, - output_dim_idx=0, - force_cpu=False, -): - """ - This OP creates a Tesnor according the shape and dtype, and initializes the - Tensor with the constants provided in ``value``. When the input is LoDTensor - and the input_dim_idx is 0, the output_dim_idx dimension is set to the value - of the batch_size input by the input, the Stop_gradient attribute of the created - Tensor is False by default. - - Args: - input(Variable): Tensor which data type is float32, float64, int32 and int64. - shape(list): The shape of Tensor to be created, Tensor's shape may be changed - according the input. - dtype(np.dtype|core.VarDesc.VarType|str): The data type of created Tensor which - can be float32, float64, int32, int64. - value(float|int): The constant value used to initialize the Tensor to be created. 
- input_dim_idx(int): When the value is 0 and the input is LoDTensor, the output_dim_idx - dimension of the created Tensor is set to the batch_size value of input. - The default value is 0. - output_dim_idx(int): Used to specify which dimension of Tensor is created to be set - the value of batch_size of input Tensor. The default value is 0. - force_cpu(bool): data should be on CPU if it's true, default value is False. - - Returns: - Variable: Tensor which will be created according to dtype. - - Examples: - - .. code-block:: python - - import paddle - import paddle.fluid as fluid - like = paddle.full(shape=[1,2], fill_value=10, dtype='int64') #like=[[10, 10]] - data = fluid.layers.fill_constant_batch_size_like( - input=like, shape=[1], value=0, dtype='int64') #like=[[10, 10]] data=[0] - - """ - if in_dygraph_mode(): - if not isinstance(dtype, core.VarDesc.VarType): - dtype = convert_np_dtype_to_dtype_(dtype) - - place = _current_expected_place() - if force_cpu: - place = core.CPUPlace() - out = _C_ops.full_batch_size_like( - input, shape, dtype, value, input_dim_idx, output_dim_idx, place - ) - out.stop_gradient = True - return out - else: - helper = LayerHelper("fill_constant_batch_size_like", **locals()) - out = helper.create_variable_for_type_inference(dtype=dtype) - attrs = { - 'shape': shape, - 'dtype': out.dtype, - 'value': float(value), - 'input_dim_idx': input_dim_idx, - 'output_dim_idx': output_dim_idx, - 'force_cpu': force_cpu, - } - if convert_dtype(dtype) in ['int64', 'int32']: - attrs['str_value'] = str(int(value)) - else: - attrs['str_value'] = str(float(value)) - helper.append_op( - type='fill_constant_batch_size_like', - inputs={'Input': input}, - outputs={'Out': [out]}, - attrs=attrs, - ) - out.stop_gradient = True - return out diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index b46f3da08cb..3f27d310624 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -48,7 +48,6 @@ from .dygraph.learning_rate_scheduler import ( _LearningRateEpochDecay, ) from paddle.fluid import core -from paddle.fluid.layers import tensor from functools import reduce from functools import cmp_to_key from .wrapped_decorator import signature_safe_contextmanager diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py index 3ef01f836fc..991df623d96 100644 --- a/python/paddle/nn/layer/rnn.py +++ b/python/paddle/nn/layer/rnn.py @@ -600,7 +600,9 @@ class RNNCellBase(Layer): class Shape: def __init__(self, shape): - self.shape = shape if shape[0] == -1 else ([-1] + list(shape)) + self.shape = ( + list(shape) if shape[0] == -1 else ([-1] + list(shape)) + ) # nested structure of shapes states_shapes = self.state_shape if shape is None else shape @@ -621,16 +623,35 @@ class RNNCellBase(Layer): states_dtypes = paddle.utils.map_structure( lambda shape: dtype, states_shapes ) + fill_shapes = states_shapes + if batch_ref.shape[batch_dim_idx] > 0: + if isinstance(fill_shapes, list): + for s in fill_shapes[0]: + s.shape[0] = batch_ref.shape[batch_dim_idx] + elif isinstance(fill_shapes, tuple): + for s in fill_shapes: + s.shape[0] = batch_ref.shape[batch_dim_idx] + else: + fill_shapes.shape[0] = batch_ref.shape[batch_dim_idx] + else: + if isinstance(fill_shapes, list): + for s in fill_shapes[0]: + s.shape[0] = paddle.shape(batch_ref)[batch_dim_idx].item() + elif isinstance(fill_shapes, tuple): + for s in fill_shapes: + s.shape[0] = paddle.shape(batch_ref)[batch_dim_idx].item() + else: + fill_shapes.shape[0] = 
paddle.shape(batch_ref)[ + batch_dim_idx + ].item() init_states = paddle.utils.map_structure( - lambda shape, dtype: paddle.fluid.layers.fill_constant_batch_size_like( - input=batch_ref, + lambda shape, dtype: paddle.full( shape=shape.shape, + fill_value=init_value, dtype=dtype, - value=init_value, - input_dim_idx=batch_dim_idx, ), - states_shapes, + fill_shapes, states_dtypes, ) return init_states @@ -1534,7 +1555,6 @@ class RNNBase(LayerList): 'Reserve': reserve, 'DropoutState': self._dropout_state, } - self._helper.append_op( type="rnn", inputs=inputs, outputs=outputs, attrs=attrs ) @@ -1555,11 +1575,15 @@ class RNNBase(LayerList): -1, self.hidden_size, ) + + fill_shape = list(state_shape) + if inputs.shape[batch_index] > 0: + fill_shape[1] = inputs.shape[batch_index] + else: + fill_shape[1] = paddle.shape(inputs)[batch_index].item() initial_states = tuple( [ - paddle.fluid.layers.fill_constant_batch_size_like( - inputs, state_shape, dtype, 0, batch_index, 1 - ) + paddle.full(shape=fill_shape, fill_value=0, dtype=dtype) for _ in range(self.state_components) ] ) diff --git a/python/paddle/nn/layer/transformer.py b/python/paddle/nn/layer/transformer.py index cab257f2e0e..335b47d2599 100644 --- a/python/paddle/nn/layer/transformer.py +++ b/python/paddle/nn/layer/transformer.py @@ -23,7 +23,6 @@ import paddle from paddle.fluid.data_feeder import convert_dtype from ... import tensor -from ...fluid import layers from ...framework import ParamAttr from .. import functional as F from .common import Dropout, Linear @@ -342,18 +341,10 @@ class MultiHeadAttention(Layer): k, v = self.compute_kv(key, value) return self.StaticCache(k, v) elif value is None: # incremental_state - k = layers.fill_constant_batch_size_like( - input=key, - shape=[-1, self.num_heads, 0, self.head_dim], - dtype=key.dtype, - value=0, - ) - v = layers.fill_constant_batch_size_like( - input=key, - shape=[-1, self.num_heads, 0, self.head_dim], - dtype=key.dtype, - value=0, - ) + fill_shape = [-1, self.num_heads, 0, self.head_dim] + fill_shape[0] = paddle.shape(key)[0].item() + k = paddle.full(fill_shape, 0, key.dtype) + v = paddle.full(fill_shape, 0, key.dtype) return self.Cache(k, v) else: # incremental_state with initial value, mainly for usage like UniLM diff --git a/test/auto_parallel/test_base_cost.py b/test/auto_parallel/test_base_cost.py index c9e3e64c6a8..62c695b9e1d 100644 --- a/test/auto_parallel/test_base_cost.py +++ b/test/auto_parallel/test_base_cost.py @@ -101,10 +101,9 @@ def mlp_forward(train_program, start_program): label = static.data( name="label", shape=[batch_size, 1], dtype='float32' ) - - fill_constant_out = paddle.fluid.layers.fill_constant_batch_size_like( - input=input, shape=[batch_size], value=1, dtype="int32" - ) + fill_shape = [batch_size] + fill_shape[0] = input.shape[0] + fill_constant_out = paddle.full(fill_shape, 1, dtype="int32") embedding = paddle.nn.Embedding(10, hidden_size, sparse=True) embedding_out = embedding(fill_constant_out) diff --git a/test/auto_parallel/test_dist_op_cost.py b/test/auto_parallel/test_dist_op_cost.py index 4d7cca7e5b3..b5ac2249873 100644 --- a/test/auto_parallel/test_dist_op_cost.py +++ b/test/auto_parallel/test_dist_op_cost.py @@ -75,9 +75,9 @@ class TestDistOpCost(unittest.TestCase): auto.shard_tensor( x, auto.ProcessMesh([0, 1], dim_names=["x"]), ["x", None] ) - tmp = paddle.fluid.layers.fill_constant_batch_size_like( - input=x, shape=[2, 8], value=1, dtype='float32' - ) + fill_shape = [2, 8] + fill_shape[0] = x.shape[0] + tmp = paddle.full(fill_shape, 
fill_value=1, dtype='float32') weight_attr = paddle.ParamAttr() linear = paddle.nn.Linear(8, 1, weight_attr=weight_attr) linear_out = linear(x) @@ -97,6 +97,8 @@ class TestDistOpCost(unittest.TestCase): op.type != "matmul_v2" and op.type != "matmul_v2_grad" and op.type != "sgd" + and op.type != "shape" + and op.type != "slice" ): dist_op = dist_context.get_dist_op_for_program(op) op_dist_attr = dist_op.dist_attr @@ -137,9 +139,9 @@ class TestDistOpCost(unittest.TestCase): ["x", None], ) # embedding - tmp = paddle.fluid.layers.fill_constant_batch_size_like( - input=x, shape=[4], value=1, dtype='int32' - ) + fill_shape = [4] + fill_shape[0] = x.shape[0] + tmp = paddle.full(shape=fill_shape, fill_value=1, dtype='int32') embedding = paddle.nn.Embedding(10, 8) out = embedding(tmp) # row parallel embedding @@ -206,23 +208,24 @@ class TestDistOpCost(unittest.TestCase): cluster = Cluster() cluster.gen_default_config_cluster(device_count=2) for idx, op in enumerate(ops): - dist_op = dist_context.get_dist_op_for_program(op) - op_dist_attr = dist_op.dist_attr - processes = op_dist_attr.process_mesh.process_ids - if is_elementwise_op(op.type): - container = get_distributed_operator_impl_container( - "elementwise" - ) - else: - container = get_distributed_operator_impl_container( - op_dist_attr.impl_type - ) + if op.type != "shape" and op.type != "slice": + dist_op = dist_context.get_dist_op_for_program(op) + op_dist_attr = dist_op.dist_attr + processes = op_dist_attr.process_mesh.process_ids + if is_elementwise_op(op.type): + container = get_distributed_operator_impl_container( + "elementwise" + ) + else: + container = get_distributed_operator_impl_container( + op_dist_attr.impl_type + ) - dist_impl = container.impls[op_dist_attr.impl_idx] - dist_op_cost = dist_impl.calc_cost( - op.attr('op_role'), dist_op, dist_context, cluster - ) - self.assertTrue(dist_op_cost) + dist_impl = container.impls[op_dist_attr.impl_idx] + dist_op_cost = dist_impl.calc_cost( + op.attr('op_role'), dist_op, dist_context, cluster + ) + self.assertTrue(dist_op_cost) def test_dist_op_cost_part3(self): def make_program(): @@ -245,9 +248,9 @@ class TestDistOpCost(unittest.TestCase): ["x", None], ) # embedding - tmp = paddle.fluid.layers.fill_constant_batch_size_like( - input=x, shape=[4], value=1, dtype='int32' - ) + fill_shape = [4] + fill_shape[0] = x.shape[0] + tmp = paddle.full(shape=fill_shape, fill_value=1, dtype='int32') embedding = paddle.nn.Embedding(10, 8) out = embedding(tmp) # row parallel embedding @@ -315,23 +318,24 @@ class TestDistOpCost(unittest.TestCase): cluster = Cluster() cluster.gen_default_config_cluster(device_count=2) for idx, op in enumerate(ops): - dist_op = dist_context.get_dist_op_for_program(op) - op_dist_attr = dist_op.dist_attr - processes = op_dist_attr.process_mesh.process_ids - if is_elementwise_op(op.type): - container = get_distributed_operator_impl_container( - "elementwise" - ) - else: - container = get_distributed_operator_impl_container( - op_dist_attr.impl_type - ) + if op.type != "shape" and op.type != "slice": + dist_op = dist_context.get_dist_op_for_program(op) + op_dist_attr = dist_op.dist_attr + processes = op_dist_attr.process_mesh.process_ids + if is_elementwise_op(op.type): + container = get_distributed_operator_impl_container( + "elementwise" + ) + else: + container = get_distributed_operator_impl_container( + op_dist_attr.impl_type + ) - dist_impl = container.impls[op_dist_attr.impl_idx] - dist_op_cost = dist_impl.calc_cost( - op.attr('op_role'), dist_op, dist_context, cluster - 
) - self.assertTrue(dist_op_cost) + dist_impl = container.impls[op_dist_attr.impl_idx] + dist_op_cost = dist_impl.calc_cost( + op.attr('op_role'), dist_op, dist_context, cluster + ) + self.assertTrue(dist_op_cost) def test_dist_op_cost_part4(self): def make_program(): @@ -353,9 +357,9 @@ class TestDistOpCost(unittest.TestCase): ["x", None], ) # embedding - tmp = paddle.fluid.layers.fill_constant_batch_size_like( - input=x, shape=[4], value=1, dtype='int32' - ) + fill_shape = [4] + fill_shape[0] = x.shape[0] + tmp = paddle.full(shape=fill_shape, fill_value=1, dtype='int32') embedding = paddle.nn.Embedding(10, 8) out = embedding(tmp) # row parallel embedding @@ -423,23 +427,24 @@ class TestDistOpCost(unittest.TestCase): cluster = Cluster() cluster.gen_default_config_cluster(device_count=2) for idx, op in enumerate(ops): - dist_op = dist_context.get_dist_op_for_program(op) - op_dist_attr = dist_op.dist_attr - processes = op_dist_attr.process_mesh.process_ids - if is_elementwise_op(op.type): - container = get_distributed_operator_impl_container( - "elementwise" - ) - else: - container = get_distributed_operator_impl_container( - op_dist_attr.impl_type - ) + if op.type != "shape" and op.type != "slice": + dist_op = dist_context.get_dist_op_for_program(op) + op_dist_attr = dist_op.dist_attr + processes = op_dist_attr.process_mesh.process_ids + if is_elementwise_op(op.type): + container = get_distributed_operator_impl_container( + "elementwise" + ) + else: + container = get_distributed_operator_impl_container( + op_dist_attr.impl_type + ) - dist_impl = container.impls[op_dist_attr.impl_idx] - dist_op_cost = dist_impl.calc_cost( - op.attr('op_role'), dist_op, dist_context, cluster - ) - self.assertTrue(dist_op_cost) + dist_impl = container.impls[op_dist_attr.impl_idx] + dist_op_cost = dist_impl.calc_cost( + op.attr('op_role'), dist_op, dist_context, cluster + ) + self.assertTrue(dist_op_cost) if __name__ == "__main__": diff --git a/test/auto_parallel/test_while_op_partition.py b/test/auto_parallel/test_while_op_partition.py index cbab4cf981f..fd8edc6eba7 100644 --- a/test/auto_parallel/test_while_op_partition.py +++ b/test/auto_parallel/test_while_op_partition.py @@ -145,8 +145,18 @@ def get_program(): auto.shard_tensor(label, _g_process_mesh, [None, None, None]) # fill constant bsz like - tmp = paddle.fluid.layers.fill_constant_batch_size_like( - input=input, shape=[-1, 16, 0, 48], dtype='float32', value=0 + block = train_program.current_block() + fill_shape = [-1, 16, 0, 48] + tmp = block.create_var(name='tmp', dtype='float32') + block.append_op( + type='fill_constant_batch_size_like', + outputs={'Out': [tmp]}, + inputs={'Input': [input]}, + attrs={ + 'shape': fill_shape, + 'value': 0, + }, + stop_gradient=True, ) auto.shard_tensor(tmp, _g_process_mesh, [None, 'x', None, None]) @@ -369,7 +379,6 @@ class TestMLP(unittest.TestCase): train_program, start_program, dist_context ) dist_context.block_state.parse_forward_blocks(train_program) - dist_main_prog, dist_startup_prog = partition( train_program, start_program, dist_context ) @@ -388,8 +397,8 @@ class TestMLP(unittest.TestCase): self.assertTrue("c_allreduce_sum" in sub_block_ops) # test fill_constant_batch_size_like - self.assertIsNotNone(fill_op) + ref_shape = [-1, 8, 0, 48] shape = fill_op.attr("shape") self.assertTrue(ref_shape == shape) diff --git a/test/legacy_test/auto_parallel_gpt_model.py b/test/legacy_test/auto_parallel_gpt_model.py index 28e63db4bf1..1be27f9bc80 100644 --- a/test/legacy_test/auto_parallel_gpt_model.py +++ 
b/test/legacy_test/auto_parallel_gpt_model.py @@ -18,7 +18,6 @@ import paddle import paddle.nn.functional as F from paddle import nn, tensor from paddle.distributed.fleet import auto -from paddle.fluid import layers from paddle.nn.layer.transformer import _convert_param_attr_to_list paddle.enable_static() @@ -212,18 +211,10 @@ class MultiHeadAttention(nn.Layer): k, v = self.compute_kv(key, value) return self.StaticCache(k, v) elif value is None: # incremental_state - k = layers.fill_constant_batch_size_like( - input=key, - shape=[-1, self.num_heads, 0, self.head_dim], - dtype=key.dtype, - value=0, - ) - v = layers.fill_constant_batch_size_like( - input=key, - shape=[-1, self.num_heads, 0, self.head_dim], - dtype=key.dtype, - value=0, - ) + fill_shape = [-1, self.num_heads, 0, self.head_dim] + fill_shape[0] = paddle.shape(key)[0].item() + k = paddle.full(shape=fill_shape, fill_value=0, dtype=key.dtype) + v = paddle.full(shape=fill_shape, fill_value=0, dtype=key.dtype) return self.Cache(k, v) else: # incremental_state with initial value, mainly for usage like UniLM diff --git a/test/legacy_test/dist_fleet_simnet_bow.py b/test/legacy_test/dist_fleet_simnet_bow.py index 8ee220682e4..5885f395694 100644 --- a/test/legacy_test/dist_fleet_simnet_bow.py +++ b/test/legacy_test/dist_fleet_simnet_bow.py @@ -68,17 +68,17 @@ def get_acc(cos_q_nt, cos_q_pt, batch_size): def get_loss(cos_q_pt, cos_q_nt): + fill_shape = [-1, 1] + fill_shape[0] = paddle.shape(cos_q_pt)[0].item() + loss_op1 = paddle.subtract( - fluid.layers.fill_constant_batch_size_like( - input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32' - ), + paddle.full(shape=fill_shape, fill_value=margin, dtype='float32'), cos_q_pt, ) loss_op2 = paddle.add(loss_op1, cos_q_nt) + fill_shape[0] = paddle.shape(cos_q_pt)[0].item() loss_op3 = paddle.maximum( - fluid.layers.fill_constant_batch_size_like( - input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32' - ), + paddle.full(shape=fill_shape, fill_value=0.0, dtype='float32'), loss_op2, ) avg_cost = paddle.mean(loss_op3) diff --git a/test/legacy_test/test_auto_parallel_completion_gpt.py b/test/legacy_test/test_auto_parallel_completion_gpt.py index cc09ac989e1..cd00ae2c736 100644 --- a/test/legacy_test/test_auto_parallel_completion_gpt.py +++ b/test/legacy_test/test_auto_parallel_completion_gpt.py @@ -23,7 +23,6 @@ from paddle.distributed.auto_parallel.static.dist_context import ( DistributedContext, ) from paddle.distributed.fleet import auto -from paddle.fluid import layers from paddle.nn.layer.transformer import _convert_param_attr_to_list paddle.enable_static() @@ -172,18 +171,10 @@ class MultiHeadAttention(nn.Layer): k, v = self.compute_kv(key, value) return self.StaticCache(k, v) elif value is None: # incremental_state - k = layers.fill_constant_batch_size_like( - input=key, - shape=[-1, self.num_heads, 0, self.head_dim], - dtype=key.dtype, - value=0, - ) - v = layers.fill_constant_batch_size_like( - input=key, - shape=[-1, self.num_heads, 0, self.head_dim], - dtype=key.dtype, - value=0, - ) + fill_shape = [-1, self.num_heads, 0, self.head_dim] + fill_shape[0] = paddle.shape(key)[0].item() + k = paddle.full(shape=fill_shape, fill_value=0, dtype=key.dtype) + v = paddle.full(shape=fill_shape, fill_value=0, dtype=key.dtype) return self.Cache(k, v) else: # incremental_state with initial value, mainly for usage like UniLM diff --git a/test/legacy_test/test_auto_parallel_partitioner_gpt.py b/test/legacy_test/test_auto_parallel_partitioner_gpt.py index 66c0eb3ea74..0828cafa60b 100644 --- 
a/test/legacy_test/test_auto_parallel_partitioner_gpt.py +++ b/test/legacy_test/test_auto_parallel_partitioner_gpt.py @@ -28,7 +28,6 @@ from paddle.distributed.auto_parallel.static.process_group import ( ) from paddle.distributed.auto_parallel.static.utils import _get_comm_group from paddle.distributed.fleet import auto -from paddle.fluid import layers from paddle.nn.layer.transformer import _convert_param_attr_to_list paddle.enable_static() @@ -218,18 +217,10 @@ class MultiHeadAttention(nn.Layer): k, v = self.compute_kv(key, value) return self.StaticCache(k, v) elif value is None: # incremental_state - k = layers.fill_constant_batch_size_like( - input=key, - shape=[-1, self.num_heads, 0, self.head_dim], - dtype=key.dtype, - value=0, - ) - v = layers.fill_constant_batch_size_like( - input=key, - shape=[-1, self.num_heads, 0, self.head_dim], - dtype=key.dtype, - value=0, - ) + fill_shape = [-1, self.num_heads, 0, self.head_dim] + fill_shape[0] = paddle.shape(key)[0].item() + k = paddle.full(shape=fill_shape, fill_value=0, dtype=key.dtype) + v = paddle.full(shape=fill_shape, fill_value=0, dtype=key.dtype) return self.Cache(k, v) else: # incremental_state with initial value, mainly for usage like UniLM diff --git a/test/legacy_test/test_dist_fleet_minimize.py b/test/legacy_test/test_dist_fleet_minimize.py index 3eb44988d88..59751a96041 100644 --- a/test/legacy_test/test_dist_fleet_minimize.py +++ b/test/legacy_test/test_dist_fleet_minimize.py @@ -49,17 +49,19 @@ class TestPSMinimize(unittest.TestCase): return acc def get_loss(cos_q_pt, cos_q_nt): + fill_shape = [-1, 1] + fill_shape[0] = paddle.shape(cos_q_pt)[0].item() loss_op1 = paddle.subtract( - fluid.layers.fill_constant_batch_size_like( - input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32' + paddle.full( + shape=fill_shape, fill_value=margin, dtype='float32' ), cos_q_pt, ) loss_op2 = paddle.add(loss_op1, cos_q_nt) + fill_shape = [-1, 1] + fill_shape[0] = paddle.shape(loss_op2)[0].item() loss_op3 = paddle.maximum( - fluid.layers.fill_constant_batch_size_like( - input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32' - ), + paddle.full(shape=fill_shape, fill_value=0.0, dtype='float32'), loss_op2, ) avg_cost = paddle.mean(loss_op3) diff --git a/test/legacy_test/test_dist_fleet_ps.py b/test/legacy_test/test_dist_fleet_ps.py index 8266a02dfbb..eb423b3c341 100644 --- a/test/legacy_test/test_dist_fleet_ps.py +++ b/test/legacy_test/test_dist_fleet_ps.py @@ -49,17 +49,19 @@ class TestPSPassWithBow(unittest.TestCase): return acc def get_loss(cos_q_pt, cos_q_nt): + fill_shape = [-1, 1] + fill_shape[0] = paddle.shape(cos_q_pt)[0].item() loss_op1 = paddle.subtract( - fluid.layers.fill_constant_batch_size_like( - input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32' + paddle.full( + shape=fill_shape, fill_value=margin, dtype='float32' ), cos_q_pt, ) loss_op2 = paddle.add(loss_op1, cos_q_nt) + fill_shape = [-1, 1] + fill_shape[0] = paddle.shape(loss_op2)[0].item() loss_op3 = paddle.maximum( - fluid.layers.fill_constant_batch_size_like( - input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32' - ), + paddle.full(shape=fill_shape, fill_value=0.0, dtype='float32'), loss_op2, ) avg_cost = paddle.mean(loss_op3) diff --git a/test/legacy_test/test_dist_fleet_ps11.py b/test/legacy_test/test_dist_fleet_ps11.py index 755636d0ab2..d5a4c64423f 100755 --- a/test/legacy_test/test_dist_fleet_ps11.py +++ b/test/legacy_test/test_dist_fleet_ps11.py @@ -49,17 +49,19 @@ class TestPSPassWithBow(unittest.TestCase): return acc def get_loss(cos_q_pt, 
cos_q_nt): + fill_shape = [-1, 1] + fill_shape[0] = paddle.shape(cos_q_pt)[0].item() loss_op1 = paddle.subtract( - fluid.layers.fill_constant_batch_size_like( - input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32' + paddle.full( + shape=fill_shape, fill_value=margin, dtype='float32' ), cos_q_pt, ) loss_op2 = paddle.add(loss_op1, cos_q_nt) + fill_shape = [-1, 1] + fill_shape[0] = paddle.shape(loss_op2)[0].item() loss_op3 = paddle.maximum( - fluid.layers.fill_constant_batch_size_like( - input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32' - ), + paddle.full(shape=fill_shape, fill_value=0.0, dtype='float32'), loss_op2, ) avg_cost = paddle.mean(loss_op3) diff --git a/test/legacy_test/test_dist_fleet_ps12.py b/test/legacy_test/test_dist_fleet_ps12.py index 1b7b30780cb..dc0d0325a5a 100644 --- a/test/legacy_test/test_dist_fleet_ps12.py +++ b/test/legacy_test/test_dist_fleet_ps12.py @@ -52,17 +52,19 @@ class TestPSPassWithBow(unittest.TestCase): return acc def get_loss(cos_q_pt, cos_q_nt): + fill_shape = [-1, 1] + fill_shape[0] = paddle.shape(cos_q_pt)[0].item() loss_op1 = paddle.subtract( - fluid.layers.fill_constant_batch_size_like( - input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32' + paddle.full( + shape=fill_shape, fill_value=margin, dtype='float32' ), cos_q_pt, ) loss_op2 = paddle.add(loss_op1, cos_q_nt) + fill_shape = [-1, 1] + fill_shape[0] = paddle.shape(loss_op2)[0].item() loss_op3 = paddle.maximum( - fluid.layers.fill_constant_batch_size_like( - input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32' - ), + paddle.full(shape=fill_shape, fill_value=0.0, dtype='float32'), loss_op2, ) avg_cost = paddle.mean(loss_op3) diff --git a/test/legacy_test/test_dist_fleet_ps13.py b/test/legacy_test/test_dist_fleet_ps13.py index 3cb1dec9ae9..2fbdbeba47f 100644 --- a/test/legacy_test/test_dist_fleet_ps13.py +++ b/test/legacy_test/test_dist_fleet_ps13.py @@ -53,17 +53,19 @@ class TestPSPassWithBow(unittest.TestCase): return acc def get_loss(cos_q_pt, cos_q_nt): + fill_shape = [-1, 1] + fill_shape[0] = paddle.shape(cos_q_pt)[0].item() loss_op1 = paddle.subtract( - fluid.layers.fill_constant_batch_size_like( - input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32' + paddle.full( + shape=fill_shape, fill_value=margin, dtype='float32' ), cos_q_pt, ) loss_op2 = paddle.add(loss_op1, cos_q_nt) + fill_shape = [-1, 1] + fill_shape[0] = paddle.shape(loss_op2)[0].item() loss_op3 = paddle.maximum( - fluid.layers.fill_constant_batch_size_like( - input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32' - ), + paddle.full(shape=fill_shape, fill_value=0.0, dtype='float32'), loss_op2, ) avg_cost = paddle.mean(loss_op3) diff --git a/test/legacy_test/test_dist_fleet_ps2.py b/test/legacy_test/test_dist_fleet_ps2.py index c6bbaee3a20..f27e4172d12 100644 --- a/test/legacy_test/test_dist_fleet_ps2.py +++ b/test/legacy_test/test_dist_fleet_ps2.py @@ -52,17 +52,19 @@ class TestPSPassWithBow(unittest.TestCase): return acc def get_loss(cos_q_pt, cos_q_nt): + fill_shape = [-1, 1] + fill_shape[0] = paddle.shape(cos_q_pt)[0].item() loss_op1 = paddle.subtract( - fluid.layers.fill_constant_batch_size_like( - input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32' + paddle.full( + shape=fill_shape, fill_value=margin, dtype='float32' ), cos_q_pt, ) loss_op2 = paddle.add(loss_op1, cos_q_nt) + fill_shape = [-1, 1] + fill_shape[0] = paddle.shape(loss_op2)[0].item() loss_op3 = paddle.maximum( - fluid.layers.fill_constant_batch_size_like( - input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32' - ), 
+ paddle.full(shape=fill_shape, fill_value=0.0, dtype='float32'), loss_op2, ) avg_cost = paddle.mean(loss_op3) diff --git a/test/legacy_test/test_dist_fleet_ps3.py b/test/legacy_test/test_dist_fleet_ps3.py index 15f0bc363db..9f1ff73b830 100644 --- a/test/legacy_test/test_dist_fleet_ps3.py +++ b/test/legacy_test/test_dist_fleet_ps3.py @@ -49,17 +49,19 @@ class TestPSPassWithBow(unittest.TestCase): return acc def get_loss(cos_q_pt, cos_q_nt): + fill_shape = [-1, 1] + fill_shape[0] = paddle.shape(cos_q_pt)[0].item() loss_op1 = paddle.subtract( - fluid.layers.fill_constant_batch_size_like( - input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32' + paddle.full( + shape=fill_shape, fill_value=margin, dtype='float32' ), cos_q_pt, ) loss_op2 = paddle.add(loss_op1, cos_q_nt) + fill_shape = [-1, 1] + fill_shape[0] = paddle.shape(loss_op2)[0].item() loss_op3 = paddle.maximum( - fluid.layers.fill_constant_batch_size_like( - input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32' - ), + paddle.full(shape=fill_shape, fill_value=0.0, dtype='float32'), loss_op2, ) avg_cost = paddle.mean(loss_op3) diff --git a/test/legacy_test/test_dist_fleet_ps4.py b/test/legacy_test/test_dist_fleet_ps4.py index b3c8dedf3ee..3d401885815 100644 --- a/test/legacy_test/test_dist_fleet_ps4.py +++ b/test/legacy_test/test_dist_fleet_ps4.py @@ -49,17 +49,19 @@ class TestPSPassWithBow(unittest.TestCase): return acc def get_loss(cos_q_pt, cos_q_nt): + fill_shape = [-1, 1] + fill_shape[0] = paddle.shape(cos_q_pt)[0].item() loss_op1 = paddle.subtract( - fluid.layers.fill_constant_batch_size_like( - input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32' + paddle.full( + shape=fill_shape, fill_value=margin, dtype='float32' ), cos_q_pt, ) loss_op2 = paddle.add(loss_op1, cos_q_nt) + fill_shape = [-1, 1] + fill_shape[0] = paddle.shape(loss_op2)[0].item() loss_op3 = paddle.maximum( - fluid.layers.fill_constant_batch_size_like( - input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32' - ), + paddle.full(shape=fill_shape, fill_value=0.0, dtype='float32'), loss_op2, ) avg_cost = paddle.mean(loss_op3) diff --git a/test/legacy_test/test_dist_fleet_ps5.py b/test/legacy_test/test_dist_fleet_ps5.py index 5eeab8dac74..efc70346ab1 100644 --- a/test/legacy_test/test_dist_fleet_ps5.py +++ b/test/legacy_test/test_dist_fleet_ps5.py @@ -49,17 +49,19 @@ class TestPSPassWithBow(unittest.TestCase): return acc def get_loss(cos_q_pt, cos_q_nt): + fill_shape = [-1, 1] + fill_shape[0] = paddle.shape(cos_q_pt)[0].item() loss_op1 = paddle.subtract( - fluid.layers.fill_constant_batch_size_like( - input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32' + paddle.full( + shape=fill_shape, fill_value=margin, dtype='float32' ), cos_q_pt, ) loss_op2 = paddle.add(loss_op1, cos_q_nt) + fill_shape = [-1, 1] + fill_shape[0] = paddle.shape(loss_op2)[0].item() loss_op3 = paddle.maximum( - fluid.layers.fill_constant_batch_size_like( - input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32' - ), + paddle.full(shape=fill_shape, fill_value=0.0, dtype='float32'), loss_op2, ) avg_cost = paddle.mean(loss_op3) diff --git a/test/legacy_test/test_dist_fleet_ps6.py b/test/legacy_test/test_dist_fleet_ps6.py index f8eaafe3cc3..c4be4348c0c 100644 --- a/test/legacy_test/test_dist_fleet_ps6.py +++ b/test/legacy_test/test_dist_fleet_ps6.py @@ -49,17 +49,19 @@ class TestPSPassWithBow(unittest.TestCase): return acc def get_loss(cos_q_pt, cos_q_nt): + fill_shape = [-1, 1] + fill_shape[0] = paddle.shape(cos_q_pt)[0].item() loss_op1 = paddle.subtract( - 
fluid.layers.fill_constant_batch_size_like( - input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32' + paddle.full( + shape=fill_shape, fill_value=margin, dtype='float32' ), cos_q_pt, ) loss_op2 = paddle.add(loss_op1, cos_q_nt) + fill_shape = [-1, 1] + fill_shape[0] = paddle.shape(loss_op2)[0].item() loss_op3 = paddle.maximum( - fluid.layers.fill_constant_batch_size_like( - input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32' - ), + paddle.full(shape=fill_shape, fill_value=0.0, dtype='float32'), loss_op2, ) avg_cost = paddle.mean(loss_op3) diff --git a/test/legacy_test/test_dist_fleet_spmt.py b/test/legacy_test/test_dist_fleet_spmt.py index 6d9d6cd86df..17e6c03693d 100644 --- a/test/legacy_test/test_dist_fleet_spmt.py +++ b/test/legacy_test/test_dist_fleet_spmt.py @@ -47,17 +47,19 @@ class TestSPMT(unittest.TestCase): return acc def get_loss(cos_q_pt, cos_q_nt): + fill_shape = [-1, 1] + fill_shape[0] = paddle.shape(cos_q_pt)[0].item() loss_op1 = paddle.subtract( - fluid.layers.fill_constant_batch_size_like( - input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32' + paddle.full( + shape=fill_shape, fill_value=margin, dtype='float32' ), cos_q_pt, ) loss_op2 = paddle.add(loss_op1, cos_q_nt) + fill_shape = [-1, 1] + fill_shape[0] = paddle.shape(loss_op2)[0].item() loss_op3 = paddle.maximum( - fluid.layers.fill_constant_batch_size_like( - input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32' - ), + paddle.full(shape=fill_shape, fill_value=0.0, dtype='float32'), loss_op2, ) avg_cost = paddle.mean(loss_op3) diff --git a/test/legacy_test/test_dist_transpiler.py b/test/legacy_test/test_dist_transpiler.py index ed23ecd294e..73ca10308eb 100644 --- a/test/legacy_test/test_dist_transpiler.py +++ b/test/legacy_test/test_dist_transpiler.py @@ -422,11 +422,15 @@ class TestFakeInit(TranspilerTest): neg_matmul_re = paddle.reshape(neg_matmul, shape=[-1, neg_num]) neg_logits = paddle.add(neg_matmul_re, neg_emb_b_vec) # nce loss - label_ones = fluid.layers.fill_constant_batch_size_like( - true_logits, shape=[-1, 1], value=1.0, dtype='float32' - ) - label_zeros = fluid.layers.fill_constant_batch_size_like( - true_logits, shape=[-1, neg_num], value=0.0, dtype='float32' + fill_shape = [-1, 1] + fill_shape[0] = paddle.shape(true_logits)[0].item() + label_ones = paddle.full( + shape=fill_shape, fill_value=1.0, dtype='float32' + ) + fill_shape = [-1, neg_num] + fill_shape[0] = paddle.shape(true_logits)[0].item() + label_zeros = paddle.full( + shape=fill_shape, fill_value=0.0, dtype='float32' ) true_xent = paddle.nn.functional.binary_cross_entropy_with_logits( diff --git a/test/legacy_test/test_fill_constant_batch_size_like.py b/test/legacy_test/test_fill_constant_batch_size_like.py deleted file mode 100644 index bd077c87984..00000000000 --- a/test/legacy_test/test_fill_constant_batch_size_like.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import unittest - -import numpy as np -from eager_op_test import OpTest, convert_float_to_uint16 - -import paddle -from paddle.fluid import core -from paddle.fluid.framework import convert_np_dtype_to_dtype_ - -paddle.enable_static() - - -def fill_constant_batch_size_like( - input, - shape, - value, - data_type, - input_dim_idx=0, - output_dim_idx=0, - force_cpu=False, -): - return paddle.fluid.layers.fill_constant_batch_size_like( - input, shape, data_type, value, input_dim_idx, output_dim_idx, force_cpu - ) - - -class TestFillConstantBatchSizeLike1(OpTest): - # test basic - def setUp(self): - self.op_type = "fill_constant_batch_size_like" - self.python_api = fill_constant_batch_size_like - self.init_dtype() - self.init_data() - - input = np.zeros(self.shape) - out = np.full_like(input, self.value, self.dtype) - - self.inputs = {'Input': input} - self.outputs = {'Out': out} - self.attrs = { - 'shape': self.shape, - 'dtype': convert_np_dtype_to_dtype_(self.dtype), - 'value': self.value, - 'input_dim_idx': self.input_dim_idx, - 'output_dim_idx': self.output_dim_idx, - 'force_cpu': self.force_cpu, - } - - def init_dtype(self): - self.dtype = np.float32 - - def init_data(self): - self.shape = [10, 10] - self.value = 100 - self.input_dim_idx = 0 - self.output_dim_idx = 0 - self.force_cpu = False - - def test_check_output(self): - self.check_output() - - -class TestFillConstantBatchSizeLikeFP16Op(TestFillConstantBatchSizeLike1): - def init_dtype(self): - self.dtype = np.float16 - - -@unittest.skipIf( - not core.is_compiled_with_cuda() or not core.supports_bfloat16(), - "core is not compiled with CUDA or place do not support bfloat16", -) -class TestFillConstantBatchSizeLikeBF16Op(OpTest): - # test bf16 - def setUp(self): - self.op_type = "fill_constant_batch_size_like" - self.python_api = fill_constant_batch_size_like - self.init_data() - - input = np.zeros(self.shape).astype("float32") - input_bf16 = convert_float_to_uint16(input) - out = np.full_like(input, self.value, np.float32) - out_bf16 = convert_float_to_uint16(out) - - self.inputs = {'Input': input_bf16} - self.outputs = {'Out': out_bf16} - self.attrs = { - 'shape': self.shape, - 'dtype': convert_np_dtype_to_dtype_(self.dtype), - 'value': self.value, - 'input_dim_idx': self.input_dim_idx, - 'output_dim_idx': self.output_dim_idx, - 'force_cpu': self.force_cpu, - } - - def init_data(self): - self.shape = [10, 10] - self.dtype = np.uint16 - self.value = 100 - self.input_dim_idx = 0 - self.output_dim_idx = 0 - self.force_cpu = False - - def test_check_output(self): - place = core.CUDAPlace(0) - self.check_output_with_place(place) - - -if __name__ == "__main__": - paddle.enable_static() - unittest.main() diff --git a/test/legacy_test/test_layers.py b/test/legacy_test/test_layers.py index ded9e08da74..44986f0b122 100644 --- a/test/legacy_test/test_layers.py +++ b/test/legacy_test/test_layers.py @@ -2133,16 +2133,6 @@ class TestBook(LayerTest): ) return out - def test_fill_constant_batch_size_like(self): - with self.static_graph(): - like = paddle.tensor.fill_constant( - shape=[1, 200], value=10, dtype='int64' - ) - out = layers.fill_constant_batch_size_like( - input=like, shape=[2, 3300], value=1315454564656, dtype='int64' - ) - return out - def test_shuffle_batch(self): # TODO(minqiyang): dygraph do not support lod now with self.static_graph(): -- GitLab