Unverified commit 6f53d3b2, authored by Difer, committed by GitHub

replace fill_constant_batch_size_like (#55522)

* simple replace

* for debug

* fix bugs

* fix some bugs

* del fill_constant_batch_size_like
Parent 2dbd47b2
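The recurring migration pattern in this commit replaces `fill_constant_batch_size_like(input, shape, dtype, value)` with an explicit shape resolved from the reference tensor, followed by `paddle.full`. A minimal sketch of the pattern (the names `ref`, `fill_shape`, and `zeros` are illustrative, not taken from the diff):

```python
import paddle

# Reference tensor whose leading (batch) dimension should be copied.
ref = paddle.rand([8, 16])

# Old: fluid.layers.fill_constant_batch_size_like(input=ref, shape=[-1, 1], value=0.0, dtype='float32')
# New: resolve the batch dimension explicitly, then call paddle.full.
fill_shape = [-1, 1]
fill_shape[0] = paddle.shape(ref)[0].item()  # read the runtime batch size from the reference tensor
zeros = paddle.full(shape=fill_shape, fill_value=0.0, dtype='float32')

print(zeros.shape)  # [8, 1]
```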
...@@ -471,6 +471,14 @@ class ScaleOpCost(CompOpCost): ...@@ -471,6 +471,14 @@ class ScaleOpCost(CompOpCost):
super().__init__(op=op, op_desc=op_desc, cluster=cluster) super().__init__(op=op, op_desc=op_desc, cluster=cluster)
@register_op_cost
class ShapeOpCost(CompOpCost):
OP_TYPE = "shape"
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
@register_op_cost @register_op_cost
class SliceOpCost(CompOpCost): class SliceOpCost(CompOpCost):
OP_TYPE = "slice" OP_TYPE = "slice"
......
...@@ -18,7 +18,7 @@ import numpy as np ...@@ -18,7 +18,7 @@ import numpy as np
import paddle import paddle
from paddle.distribution import exponential_family from paddle.distribution import exponential_family
from paddle.fluid.data_feeder import check_type, convert_dtype from paddle.fluid.data_feeder import check_type, convert_dtype
from paddle.fluid.layers import tensor from paddle.fluid.framework import Variable
from paddle.framework import in_dynamic_mode from paddle.framework import in_dynamic_mode
from paddle.nn.functional import ( from paddle.nn.functional import (
binary_cross_entropy_with_logits, binary_cross_entropy_with_logits,
...@@ -97,7 +97,7 @@ class Bernoulli(exponential_family.ExponentialFamily): ...@@ -97,7 +97,7 @@ class Bernoulli(exponential_family.ExponentialFamily):
check_type( check_type(
probs, probs,
'probs', 'probs',
(float, tensor.Variable), (float, Variable),
self.name, self.name,
) )
...@@ -180,7 +180,7 @@ class Bernoulli(exponential_family.ExponentialFamily): ...@@ -180,7 +180,7 @@ class Bernoulli(exponential_family.ExponentialFamily):
check_type( check_type(
shape, shape,
'shape', 'shape',
(np.ndarray, tensor.Variable, list, tuple), (np.ndarray, Variable, list, tuple),
name, name,
) )
...@@ -259,7 +259,7 @@ class Bernoulli(exponential_family.ExponentialFamily): ...@@ -259,7 +259,7 @@ class Bernoulli(exponential_family.ExponentialFamily):
check_type( check_type(
shape, shape,
'shape', 'shape',
(np.ndarray, tensor.Variable, list, tuple), (np.ndarray, Variable, list, tuple),
name, name,
) )
check_type( check_type(
...@@ -318,7 +318,7 @@ class Bernoulli(exponential_family.ExponentialFamily): ...@@ -318,7 +318,7 @@ class Bernoulli(exponential_family.ExponentialFamily):
""" """
name = self.name + '_cdf' name = self.name + '_cdf'
if not in_dynamic_mode(): if not in_dynamic_mode():
check_type(value, 'value', tensor.Variable, name) check_type(value, 'value', Variable, name)
value = self._check_values_dtype_in_probs(self.probs, value) value = self._check_values_dtype_in_probs(self.probs, value)
probs, value = paddle.broadcast_tensors([self.probs, value]) probs, value = paddle.broadcast_tensors([self.probs, value])
...@@ -356,7 +356,7 @@ class Bernoulli(exponential_family.ExponentialFamily): ...@@ -356,7 +356,7 @@ class Bernoulli(exponential_family.ExponentialFamily):
""" """
name = self.name + '_log_prob' name = self.name + '_log_prob'
if not in_dynamic_mode(): if not in_dynamic_mode():
check_type(value, 'value', tensor.Variable, name) check_type(value, 'value', Variable, name)
value = self._check_values_dtype_in_probs(self.probs, value) value = self._check_values_dtype_in_probs(self.probs, value)
logits, value = paddle.broadcast_tensors([self.logits, value]) logits, value = paddle.broadcast_tensors([self.logits, value])
...@@ -395,7 +395,7 @@ class Bernoulli(exponential_family.ExponentialFamily): ...@@ -395,7 +395,7 @@ class Bernoulli(exponential_family.ExponentialFamily):
""" """
name = self.name + '_prob' name = self.name + '_prob'
if not in_dynamic_mode(): if not in_dynamic_mode():
check_type(value, 'value', tensor.Variable, name) check_type(value, 'value', Variable, name)
return self.log_prob(value).exp(name=name) return self.log_prob(value).exp(name=name)
......
...@@ -17,7 +17,7 @@ import numpy as np ...@@ -17,7 +17,7 @@ import numpy as np
import paddle import paddle
from paddle.distribution import distribution from paddle.distribution import distribution
from paddle.fluid.data_feeder import check_type, convert_dtype from paddle.fluid.data_feeder import check_type, convert_dtype
from paddle.fluid.layers import tensor from paddle.fluid.framework import Variable
from paddle.framework import in_dynamic_mode from paddle.framework import in_dynamic_mode
from paddle.tensor import multinomial from paddle.tensor import multinomial
...@@ -100,7 +100,7 @@ class Categorical(distribution.Distribution): ...@@ -100,7 +100,7 @@ class Categorical(distribution.Distribution):
check_type( check_type(
logits, logits,
'logits', 'logits',
(np.ndarray, tensor.Variable, list, tuple), (np.ndarray, Variable, list, tuple),
'Categorical', 'Categorical',
) )
......
...@@ -26,7 +26,7 @@ import numpy as np ...@@ -26,7 +26,7 @@ import numpy as np
import paddle import paddle
from paddle import _C_ops from paddle import _C_ops
from paddle.fluid.data_feeder import check_variable_and_dtype, convert_dtype from paddle.fluid.data_feeder import check_variable_and_dtype, convert_dtype
from paddle.fluid.layers import tensor from paddle.fluid.framework import Variable
from paddle.framework import in_dynamic_mode from paddle.framework import in_dynamic_mode
...@@ -150,7 +150,7 @@ class Distribution: ...@@ -150,7 +150,7 @@ class Distribution:
is_variable = False is_variable = False
is_number = False is_number = False
for arg in args: for arg in args:
if isinstance(arg, tensor.Variable): if isinstance(arg, Variable):
is_variable = True is_variable = True
else: else:
is_number = True is_number = True
...@@ -176,9 +176,7 @@ class Distribution: ...@@ -176,9 +176,7 @@ class Distribution:
tmp = 0.0 tmp = 0.0
for arg in args: for arg in args:
if not isinstance( if not isinstance(arg, (float, list, tuple, np.ndarray, Variable)):
arg, (float, list, tuple, np.ndarray, tensor.Variable)
):
raise TypeError( raise TypeError(
"Type of input args must be float, list, tuple, numpy.ndarray or Tensor, but received type {}".format( "Type of input args must be float, list, tuple, numpy.ndarray or Tensor, but received type {}".format(
type(arg) type(arg)
......
...@@ -20,7 +20,7 @@ import numpy as np ...@@ -20,7 +20,7 @@ import numpy as np
import paddle import paddle
from paddle.distribution import distribution from paddle.distribution import distribution
from paddle.fluid.data_feeder import check_type, convert_dtype from paddle.fluid.data_feeder import check_type, convert_dtype
from paddle.fluid.layers import tensor from paddle.fluid.framework import Variable
from paddle.framework import in_dynamic_mode from paddle.framework import in_dynamic_mode
from paddle.tensor import random from paddle.tensor import random
...@@ -91,13 +91,13 @@ class Normal(distribution.Distribution): ...@@ -91,13 +91,13 @@ class Normal(distribution.Distribution):
check_type( check_type(
loc, loc,
'loc', 'loc',
(int, float, np.ndarray, tensor.Variable, list, tuple), (int, float, np.ndarray, Variable, list, tuple),
'Normal', 'Normal',
) )
check_type( check_type(
scale, scale,
'scale', 'scale',
(int, float, np.ndarray, tensor.Variable, list, tuple), (int, float, np.ndarray, Variable, list, tuple),
'Normal', 'Normal',
) )
...@@ -174,9 +174,9 @@ class Normal(distribution.Distribution): ...@@ -174,9 +174,9 @@ class Normal(distribution.Distribution):
name = self.name + '_sample' name = self.name + '_sample'
if -1 in batch_shape: if -1 in batch_shape:
output_shape = shape + batch_shape output_shape = shape + batch_shape
zero_tmp = tensor.fill_constant_batch_size_like( fill_shape = list(batch_shape + shape)
self.loc + self.scale, batch_shape + shape, self.dtype, 0.0 fill_shape[0] = paddle.shape(self.loc + self.scale)[0].item()
) zero_tmp = paddle.full(fill_shape, 0.0, self.dtype)
zero_tmp_reshape = paddle.reshape(zero_tmp, output_shape) zero_tmp_reshape = paddle.reshape(zero_tmp, output_shape)
zero_tmp_shape = paddle.shape(zero_tmp_reshape) zero_tmp_shape = paddle.shape(zero_tmp_reshape)
...@@ -234,9 +234,10 @@ class Normal(distribution.Distribution): ...@@ -234,9 +234,10 @@ class Normal(distribution.Distribution):
name = self.name + '_entropy' name = self.name + '_entropy'
batch_shape = list((self.loc + self.scale).shape) batch_shape = list((self.loc + self.scale).shape)
if -1 in batch_shape: if -1 in batch_shape:
zero_tmp = tensor.fill_constant_batch_size_like( fill_shape = list(batch_shape)
self.loc + self.scale, batch_shape, self.dtype, 0.0 fill_shape[0] = paddle.shape(self.loc + self.scale)[0].item()
) fill_dtype = (self.loc + self.scale).dtype
zero_tmp = paddle.full(fill_shape, 0.0, fill_dtype)
else: else:
zero_tmp = paddle.full(batch_shape, 0.0, self.dtype) zero_tmp = paddle.full(batch_shape, 0.0, self.dtype)
return paddle.add( return paddle.add(
......
...@@ -18,7 +18,7 @@ import paddle ...@@ -18,7 +18,7 @@ import paddle
from paddle import _C_ops from paddle import _C_ops
from paddle.distribution import distribution from paddle.distribution import distribution
from paddle.fluid.data_feeder import check_type, convert_dtype from paddle.fluid.data_feeder import check_type, convert_dtype
from paddle.fluid.layers import tensor from paddle.fluid.framework import Variable
from paddle.framework import in_dynamic_mode from paddle.framework import in_dynamic_mode
from paddle.tensor import random from paddle.tensor import random
...@@ -105,13 +105,13 @@ class Uniform(distribution.Distribution): ...@@ -105,13 +105,13 @@ class Uniform(distribution.Distribution):
check_type( check_type(
low, low,
'low', 'low',
(int, float, np.ndarray, tensor.Variable, list, tuple), (int, float, np.ndarray, Variable, list, tuple),
'Uniform', 'Uniform',
) )
check_type( check_type(
high, high,
'high', 'high',
(int, float, np.ndarray, tensor.Variable, list, tuple), (int, float, np.ndarray, Variable, list, tuple),
'Uniform', 'Uniform',
) )
...@@ -169,9 +169,9 @@ class Uniform(distribution.Distribution): ...@@ -169,9 +169,9 @@ class Uniform(distribution.Distribution):
batch_shape = list((self.low + self.high).shape) batch_shape = list((self.low + self.high).shape)
if -1 in batch_shape: if -1 in batch_shape:
output_shape = shape + batch_shape output_shape = shape + batch_shape
zero_tmp = tensor.fill_constant_batch_size_like( fill_shape = list(batch_shape + shape)
self.low + self.high, batch_shape + shape, self.dtype, 0.0 fill_shape[0] = paddle.shape(self.low + self.high)[0].item()
) zero_tmp = paddle.full(fill_shape, 0.0, self.dtype)
uniform_random_tmp = random.uniform_random_batch_size_like( uniform_random_tmp = random.uniform_random_batch_size_like(
zero_tmp, zero_tmp,
zero_tmp.shape, zero_tmp.shape,
......
...@@ -16,8 +16,6 @@ from . import nn ...@@ -16,8 +16,6 @@ from . import nn
from .nn import * from .nn import *
from . import io from . import io
from .io import * from .io import *
from . import tensor
from .tensor import *
from . import math_op_patch from . import math_op_patch
from .math_op_patch import * from .math_op_patch import *
from .learning_rate_scheduler import * from .learning_rate_scheduler import *
...@@ -27,5 +25,4 @@ from ..layer_helper import LayerHelper ...@@ -27,5 +25,4 @@ from ..layer_helper import LayerHelper
__all__ = [] __all__ = []
__all__ += nn.__all__ __all__ += nn.__all__
__all__ += io.__all__ __all__ += io.__all__
__all__ += tensor.__all__
__all__ += learning_rate_scheduler.__all__ __all__ += learning_rate_scheduler.__all__
...@@ -25,7 +25,6 @@ import numbers ...@@ -25,7 +25,6 @@ import numbers
import paddle import paddle
from . import nn from . import nn
from . import tensor
from ..framework import ( from ..framework import (
default_main_program, default_main_program,
Parameter, Parameter,
...@@ -488,7 +487,7 @@ def cosine_decay(learning_rate, step_each_epoch, epochs): ...@@ -488,7 +487,7 @@ def cosine_decay(learning_rate, step_each_epoch, epochs):
learning_rate = base_lr, step_each_epoch=10000, epochs=120) learning_rate = base_lr, step_each_epoch=10000, epochs=120)
""" """
check_type( check_type(
learning_rate, 'learning_rate', (float, tensor.Variable), 'cosine_decay' learning_rate, 'learning_rate', (float, Variable), 'cosine_decay'
) )
with default_main_program()._lr_schedule_guard(): with default_main_program()._lr_schedule_guard():
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy
import warnings
from ..layer_helper import LayerHelper
from ..framework import (
_current_expected_place,
convert_np_dtype_to_dtype_,
_create_tensor,
in_dygraph_mode,
)
from ..framework import Variable
from ..core import VarDesc
from .. import core
from .layer_function_generator import templatedoc
from ..data_feeder import (
check_variable_and_dtype,
check_type,
check_dtype,
convert_dtype,
)
from paddle.utils import deprecated
from paddle import _C_ops, _legacy_C_ops
__all__ = [
'fill_constant_batch_size_like',
]
@deprecated(since='1.8.0', update_to="paddle.fluid.layers.fill_constant")
@templatedoc()
def fill_constant_batch_size_like(
input,
shape,
dtype,
value,
input_dim_idx=0,
output_dim_idx=0,
force_cpu=False,
):
"""
This OP creates a Tensor with the given shape and dtype, and initializes it with
the constant provided in ``value``. When the input is a LoDTensor and
``input_dim_idx`` is 0, the ``output_dim_idx`` dimension of the output is set to the
batch size of the input. The stop_gradient attribute of the created Tensor is set
to True.
Args:
input(Variable): Tensor whose data type is float32, float64, int32 or int64.
shape(list): The shape of the Tensor to be created; the batch dimension may be
overridden according to the input.
dtype(np.dtype|core.VarDesc.VarType|str): The data type of the created Tensor,
which can be float32, float64, int32 or int64.
value(float|int): The constant value used to initialize the Tensor to be created.
input_dim_idx(int): When this is 0 and the input is a LoDTensor, the output_dim_idx
dimension of the created Tensor is set to the batch size of the input.
The default value is 0.
output_dim_idx(int): Specifies which dimension of the created Tensor is set to the
batch size of the input Tensor. The default value is 0.
force_cpu(bool): Whether to place the data on CPU. The default value is False.
Returns:
Variable: Tensor which will be created according to dtype.
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
like = paddle.full(shape=[1,2], fill_value=10, dtype='int64') #like=[[10, 10]]
data = fluid.layers.fill_constant_batch_size_like(
input=like, shape=[1], value=0, dtype='int64') #like=[[10, 10]] data=[0]
"""
if in_dygraph_mode():
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
place = _current_expected_place()
if force_cpu:
place = core.CPUPlace()
out = _C_ops.full_batch_size_like(
input, shape, dtype, value, input_dim_idx, output_dim_idx, place
)
out.stop_gradient = True
return out
else:
helper = LayerHelper("fill_constant_batch_size_like", **locals())
out = helper.create_variable_for_type_inference(dtype=dtype)
attrs = {
'shape': shape,
'dtype': out.dtype,
'value': float(value),
'input_dim_idx': input_dim_idx,
'output_dim_idx': output_dim_idx,
'force_cpu': force_cpu,
}
if convert_dtype(dtype) in ['int64', 'int32']:
attrs['str_value'] = str(int(value))
else:
attrs['str_value'] = str(float(value))
helper.append_op(
type='fill_constant_batch_size_like',
inputs={'Input': input},
outputs={'Out': [out]},
attrs=attrs,
)
out.stop_gradient = True
return out
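With the helper deleted, the example from its docstring can be written directly with `paddle.full`. A minimal sketch assuming dynamic-graph mode:

```python
import paddle

# Equivalent of the removed docstring example, expressed with paddle.full.
like = paddle.full(shape=[1, 2], fill_value=10, dtype='int64')      # like = [[10, 10]]
fill_shape = [1]
fill_shape[0] = paddle.shape(like)[0].item()                        # batch dim taken from `like`
data = paddle.full(shape=fill_shape, fill_value=0, dtype='int64')   # data = [0]
print(data.numpy())
```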
...@@ -48,7 +48,6 @@ from .dygraph.learning_rate_scheduler import ( ...@@ -48,7 +48,6 @@ from .dygraph.learning_rate_scheduler import (
_LearningRateEpochDecay, _LearningRateEpochDecay,
) )
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.layers import tensor
from functools import reduce from functools import reduce
from functools import cmp_to_key from functools import cmp_to_key
from .wrapped_decorator import signature_safe_contextmanager from .wrapped_decorator import signature_safe_contextmanager
......
...@@ -600,7 +600,9 @@ class RNNCellBase(Layer): ...@@ -600,7 +600,9 @@ class RNNCellBase(Layer):
class Shape: class Shape:
def __init__(self, shape): def __init__(self, shape):
self.shape = shape if shape[0] == -1 else ([-1] + list(shape)) self.shape = (
list(shape) if shape[0] == -1 else ([-1] + list(shape))
)
# nested structure of shapes # nested structure of shapes
states_shapes = self.state_shape if shape is None else shape states_shapes = self.state_shape if shape is None else shape
...@@ -621,16 +623,35 @@ class RNNCellBase(Layer): ...@@ -621,16 +623,35 @@ class RNNCellBase(Layer):
states_dtypes = paddle.utils.map_structure( states_dtypes = paddle.utils.map_structure(
lambda shape: dtype, states_shapes lambda shape: dtype, states_shapes
) )
fill_shapes = states_shapes
if batch_ref.shape[batch_dim_idx] > 0:
if isinstance(fill_shapes, list):
for s in fill_shapes[0]:
s.shape[0] = batch_ref.shape[batch_dim_idx]
elif isinstance(fill_shapes, tuple):
for s in fill_shapes:
s.shape[0] = batch_ref.shape[batch_dim_idx]
else:
fill_shapes.shape[0] = batch_ref.shape[batch_dim_idx]
else:
if isinstance(fill_shapes, list):
for s in fill_shapes[0]:
s.shape[0] = paddle.shape(batch_ref)[batch_dim_idx].item()
elif isinstance(fill_shapes, tuple):
for s in fill_shapes:
s.shape[0] = paddle.shape(batch_ref)[batch_dim_idx].item()
else:
fill_shapes.shape[0] = paddle.shape(batch_ref)[
batch_dim_idx
].item()
init_states = paddle.utils.map_structure( init_states = paddle.utils.map_structure(
lambda shape, dtype: paddle.fluid.layers.fill_constant_batch_size_like( lambda shape, dtype: paddle.full(
input=batch_ref,
shape=shape.shape, shape=shape.shape,
fill_value=init_value,
dtype=dtype, dtype=dtype,
value=init_value,
input_dim_idx=batch_dim_idx,
), ),
states_shapes, fill_shapes,
states_dtypes, states_dtypes,
) )
return init_states return init_states
...@@ -1534,7 +1555,6 @@ class RNNBase(LayerList): ...@@ -1534,7 +1555,6 @@ class RNNBase(LayerList):
'Reserve': reserve, 'Reserve': reserve,
'DropoutState': self._dropout_state, 'DropoutState': self._dropout_state,
} }
self._helper.append_op( self._helper.append_op(
type="rnn", inputs=inputs, outputs=outputs, attrs=attrs type="rnn", inputs=inputs, outputs=outputs, attrs=attrs
) )
...@@ -1555,11 +1575,15 @@ class RNNBase(LayerList): ...@@ -1555,11 +1575,15 @@ class RNNBase(LayerList):
-1, -1,
self.hidden_size, self.hidden_size,
) )
fill_shape = list(state_shape)
if inputs.shape[batch_index] > 0:
fill_shape[1] = inputs.shape[batch_index]
else:
fill_shape[1] = paddle.shape(inputs)[batch_index].item()
initial_states = tuple( initial_states = tuple(
[ [
paddle.fluid.layers.fill_constant_batch_size_like( paddle.full(shape=fill_shape, fill_value=0, dtype=dtype)
inputs, state_shape, dtype, 0, batch_index, 1
)
for _ in range(self.state_components) for _ in range(self.state_components)
] ]
) )
......
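The RNN initial-state rewrite above branches on whether the batch dimension of `batch_ref` is statically known (greater than 0) or must be read from the runtime shape. A minimal sketch of that branching; the helper name `resolve_batch_dim` is illustrative only:

```python
import paddle

def resolve_batch_dim(batch_ref, batch_dim_idx=0):
    # Use the compile-time size when it is known; otherwise (-1 in a static
    # graph) fall back to the runtime shape.
    if batch_ref.shape[batch_dim_idx] > 0:
        return batch_ref.shape[batch_dim_idx]
    return paddle.shape(batch_ref)[batch_dim_idx].item()

x = paddle.rand([4, 10])
print(resolve_batch_dim(x))  # 4
```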
...@@ -23,7 +23,6 @@ import paddle ...@@ -23,7 +23,6 @@ import paddle
from paddle.fluid.data_feeder import convert_dtype from paddle.fluid.data_feeder import convert_dtype
from ... import tensor from ... import tensor
from ...fluid import layers
from ...framework import ParamAttr from ...framework import ParamAttr
from .. import functional as F from .. import functional as F
from .common import Dropout, Linear from .common import Dropout, Linear
...@@ -342,18 +341,10 @@ class MultiHeadAttention(Layer): ...@@ -342,18 +341,10 @@ class MultiHeadAttention(Layer):
k, v = self.compute_kv(key, value) k, v = self.compute_kv(key, value)
return self.StaticCache(k, v) return self.StaticCache(k, v)
elif value is None: # incremental_state elif value is None: # incremental_state
k = layers.fill_constant_batch_size_like( fill_shape = [-1, self.num_heads, 0, self.head_dim]
input=key, fill_shape[0] = paddle.shape(key)[0].item()
shape=[-1, self.num_heads, 0, self.head_dim], k = paddle.full(fill_shape, 0, key.dtype)
dtype=key.dtype, v = paddle.full(fill_shape, 0, key.dtype)
value=0,
)
v = layers.fill_constant_batch_size_like(
input=key,
shape=[-1, self.num_heads, 0, self.head_dim],
dtype=key.dtype,
value=0,
)
return self.Cache(k, v) return self.Cache(k, v)
else: else:
# incremental_state with initial value, mainly for usage like UniLM # incremental_state with initial value, mainly for usage like UniLM
......
...@@ -101,10 +101,9 @@ def mlp_forward(train_program, start_program): ...@@ -101,10 +101,9 @@ def mlp_forward(train_program, start_program):
label = static.data( label = static.data(
name="label", shape=[batch_size, 1], dtype='float32' name="label", shape=[batch_size, 1], dtype='float32'
) )
fill_shape = [batch_size]
fill_constant_out = paddle.fluid.layers.fill_constant_batch_size_like( fill_shape[0] = input.shape[0]
input=input, shape=[batch_size], value=1, dtype="int32" fill_constant_out = paddle.full(fill_shape, 1, dtype="int32")
)
embedding = paddle.nn.Embedding(10, hidden_size, sparse=True) embedding = paddle.nn.Embedding(10, hidden_size, sparse=True)
embedding_out = embedding(fill_constant_out) embedding_out = embedding(fill_constant_out)
......
...@@ -75,9 +75,9 @@ class TestDistOpCost(unittest.TestCase): ...@@ -75,9 +75,9 @@ class TestDistOpCost(unittest.TestCase):
auto.shard_tensor( auto.shard_tensor(
x, auto.ProcessMesh([0, 1], dim_names=["x"]), ["x", None] x, auto.ProcessMesh([0, 1], dim_names=["x"]), ["x", None]
) )
tmp = paddle.fluid.layers.fill_constant_batch_size_like( fill_shape = [2, 8]
input=x, shape=[2, 8], value=1, dtype='float32' fill_shape[0] = x.shape[0]
) tmp = paddle.full(fill_shape, fill_value=1, dtype='float32')
weight_attr = paddle.ParamAttr() weight_attr = paddle.ParamAttr()
linear = paddle.nn.Linear(8, 1, weight_attr=weight_attr) linear = paddle.nn.Linear(8, 1, weight_attr=weight_attr)
linear_out = linear(x) linear_out = linear(x)
...@@ -97,6 +97,8 @@ class TestDistOpCost(unittest.TestCase): ...@@ -97,6 +97,8 @@ class TestDistOpCost(unittest.TestCase):
op.type != "matmul_v2" op.type != "matmul_v2"
and op.type != "matmul_v2_grad" and op.type != "matmul_v2_grad"
and op.type != "sgd" and op.type != "sgd"
and op.type != "shape"
and op.type != "slice"
): ):
dist_op = dist_context.get_dist_op_for_program(op) dist_op = dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
...@@ -137,9 +139,9 @@ class TestDistOpCost(unittest.TestCase): ...@@ -137,9 +139,9 @@ class TestDistOpCost(unittest.TestCase):
["x", None], ["x", None],
) )
# embedding # embedding
tmp = paddle.fluid.layers.fill_constant_batch_size_like( fill_shape = [4]
input=x, shape=[4], value=1, dtype='int32' fill_shape[0] = x.shape[0]
) tmp = paddle.full(shape=fill_shape, fill_value=1, dtype='int32')
embedding = paddle.nn.Embedding(10, 8) embedding = paddle.nn.Embedding(10, 8)
out = embedding(tmp) out = embedding(tmp)
# row parallel embedding # row parallel embedding
...@@ -206,23 +208,24 @@ class TestDistOpCost(unittest.TestCase): ...@@ -206,23 +208,24 @@ class TestDistOpCost(unittest.TestCase):
cluster = Cluster() cluster = Cluster()
cluster.gen_default_config_cluster(device_count=2) cluster.gen_default_config_cluster(device_count=2)
for idx, op in enumerate(ops): for idx, op in enumerate(ops):
dist_op = dist_context.get_dist_op_for_program(op) if op.type != "shape" and op.type != "slice":
op_dist_attr = dist_op.dist_attr dist_op = dist_context.get_dist_op_for_program(op)
processes = op_dist_attr.process_mesh.process_ids op_dist_attr = dist_op.dist_attr
if is_elementwise_op(op.type): processes = op_dist_attr.process_mesh.process_ids
container = get_distributed_operator_impl_container( if is_elementwise_op(op.type):
"elementwise" container = get_distributed_operator_impl_container(
) "elementwise"
else: )
container = get_distributed_operator_impl_container( else:
op_dist_attr.impl_type container = get_distributed_operator_impl_container(
) op_dist_attr.impl_type
)
dist_impl = container.impls[op_dist_attr.impl_idx] dist_impl = container.impls[op_dist_attr.impl_idx]
dist_op_cost = dist_impl.calc_cost( dist_op_cost = dist_impl.calc_cost(
op.attr('op_role'), dist_op, dist_context, cluster op.attr('op_role'), dist_op, dist_context, cluster
) )
self.assertTrue(dist_op_cost) self.assertTrue(dist_op_cost)
def test_dist_op_cost_part3(self): def test_dist_op_cost_part3(self):
def make_program(): def make_program():
...@@ -245,9 +248,9 @@ class TestDistOpCost(unittest.TestCase): ...@@ -245,9 +248,9 @@ class TestDistOpCost(unittest.TestCase):
["x", None], ["x", None],
) )
# embedding # embedding
tmp = paddle.fluid.layers.fill_constant_batch_size_like( fill_shape = [4]
input=x, shape=[4], value=1, dtype='int32' fill_shape[0] = x.shape[0]
) tmp = paddle.full(shape=fill_shape, fill_value=1, dtype='int32')
embedding = paddle.nn.Embedding(10, 8) embedding = paddle.nn.Embedding(10, 8)
out = embedding(tmp) out = embedding(tmp)
# row parallel embedding # row parallel embedding
...@@ -315,23 +318,24 @@ class TestDistOpCost(unittest.TestCase): ...@@ -315,23 +318,24 @@ class TestDistOpCost(unittest.TestCase):
cluster = Cluster() cluster = Cluster()
cluster.gen_default_config_cluster(device_count=2) cluster.gen_default_config_cluster(device_count=2)
for idx, op in enumerate(ops): for idx, op in enumerate(ops):
dist_op = dist_context.get_dist_op_for_program(op) if op.type != "shape" and op.type != "slice":
op_dist_attr = dist_op.dist_attr dist_op = dist_context.get_dist_op_for_program(op)
processes = op_dist_attr.process_mesh.process_ids op_dist_attr = dist_op.dist_attr
if is_elementwise_op(op.type): processes = op_dist_attr.process_mesh.process_ids
container = get_distributed_operator_impl_container( if is_elementwise_op(op.type):
"elementwise" container = get_distributed_operator_impl_container(
) "elementwise"
else: )
container = get_distributed_operator_impl_container( else:
op_dist_attr.impl_type container = get_distributed_operator_impl_container(
) op_dist_attr.impl_type
)
dist_impl = container.impls[op_dist_attr.impl_idx] dist_impl = container.impls[op_dist_attr.impl_idx]
dist_op_cost = dist_impl.calc_cost( dist_op_cost = dist_impl.calc_cost(
op.attr('op_role'), dist_op, dist_context, cluster op.attr('op_role'), dist_op, dist_context, cluster
) )
self.assertTrue(dist_op_cost) self.assertTrue(dist_op_cost)
def test_dist_op_cost_part4(self): def test_dist_op_cost_part4(self):
def make_program(): def make_program():
...@@ -353,9 +357,9 @@ class TestDistOpCost(unittest.TestCase): ...@@ -353,9 +357,9 @@ class TestDistOpCost(unittest.TestCase):
["x", None], ["x", None],
) )
# embedding # embedding
tmp = paddle.fluid.layers.fill_constant_batch_size_like( fill_shape = [4]
input=x, shape=[4], value=1, dtype='int32' fill_shape[0] = x.shape[0]
) tmp = paddle.full(shape=fill_shape, fill_value=1, dtype='int32')
embedding = paddle.nn.Embedding(10, 8) embedding = paddle.nn.Embedding(10, 8)
out = embedding(tmp) out = embedding(tmp)
# row parallel embedding # row parallel embedding
...@@ -423,23 +427,24 @@ class TestDistOpCost(unittest.TestCase): ...@@ -423,23 +427,24 @@ class TestDistOpCost(unittest.TestCase):
cluster = Cluster() cluster = Cluster()
cluster.gen_default_config_cluster(device_count=2) cluster.gen_default_config_cluster(device_count=2)
for idx, op in enumerate(ops): for idx, op in enumerate(ops):
dist_op = dist_context.get_dist_op_for_program(op) if op.type != "shape" and op.type != "slice":
op_dist_attr = dist_op.dist_attr dist_op = dist_context.get_dist_op_for_program(op)
processes = op_dist_attr.process_mesh.process_ids op_dist_attr = dist_op.dist_attr
if is_elementwise_op(op.type): processes = op_dist_attr.process_mesh.process_ids
container = get_distributed_operator_impl_container( if is_elementwise_op(op.type):
"elementwise" container = get_distributed_operator_impl_container(
) "elementwise"
else: )
container = get_distributed_operator_impl_container( else:
op_dist_attr.impl_type container = get_distributed_operator_impl_container(
) op_dist_attr.impl_type
)
dist_impl = container.impls[op_dist_attr.impl_idx] dist_impl = container.impls[op_dist_attr.impl_idx]
dist_op_cost = dist_impl.calc_cost( dist_op_cost = dist_impl.calc_cost(
op.attr('op_role'), dist_op, dist_context, cluster op.attr('op_role'), dist_op, dist_context, cluster
) )
self.assertTrue(dist_op_cost) self.assertTrue(dist_op_cost)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -145,8 +145,18 @@ def get_program(): ...@@ -145,8 +145,18 @@ def get_program():
auto.shard_tensor(label, _g_process_mesh, [None, None, None]) auto.shard_tensor(label, _g_process_mesh, [None, None, None])
# fill constant bsz like # fill constant bsz like
tmp = paddle.fluid.layers.fill_constant_batch_size_like( block = train_program.current_block()
input=input, shape=[-1, 16, 0, 48], dtype='float32', value=0 fill_shape = [-1, 16, 0, 48]
tmp = block.create_var(name='tmp', dtype='float32')
block.append_op(
type='fill_constant_batch_size_like',
outputs={'Out': [tmp]},
inputs={'Input': [input]},
attrs={
'shape': fill_shape,
'value': 0,
},
stop_gradient=True,
) )
auto.shard_tensor(tmp, _g_process_mesh, [None, 'x', None, None]) auto.shard_tensor(tmp, _g_process_mesh, [None, 'x', None, None])
...@@ -369,7 +379,6 @@ class TestMLP(unittest.TestCase): ...@@ -369,7 +379,6 @@ class TestMLP(unittest.TestCase):
train_program, start_program, dist_context train_program, start_program, dist_context
) )
dist_context.block_state.parse_forward_blocks(train_program) dist_context.block_state.parse_forward_blocks(train_program)
dist_main_prog, dist_startup_prog = partition( dist_main_prog, dist_startup_prog = partition(
train_program, start_program, dist_context train_program, start_program, dist_context
) )
...@@ -388,8 +397,8 @@ class TestMLP(unittest.TestCase): ...@@ -388,8 +397,8 @@ class TestMLP(unittest.TestCase):
self.assertTrue("c_allreduce_sum" in sub_block_ops) self.assertTrue("c_allreduce_sum" in sub_block_ops)
# test fill_constant_batch_size_like # test fill_constant_batch_size_like
self.assertIsNotNone(fill_op) self.assertIsNotNone(fill_op)
ref_shape = [-1, 8, 0, 48] ref_shape = [-1, 8, 0, 48]
shape = fill_op.attr("shape") shape = fill_op.attr("shape")
self.assertTrue(ref_shape == shape) self.assertTrue(ref_shape == shape)
......
...@@ -18,7 +18,6 @@ import paddle ...@@ -18,7 +18,6 @@ import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import nn, tensor from paddle import nn, tensor
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list from paddle.nn.layer.transformer import _convert_param_attr_to_list
paddle.enable_static() paddle.enable_static()
...@@ -212,18 +211,10 @@ class MultiHeadAttention(nn.Layer): ...@@ -212,18 +211,10 @@ class MultiHeadAttention(nn.Layer):
k, v = self.compute_kv(key, value) k, v = self.compute_kv(key, value)
return self.StaticCache(k, v) return self.StaticCache(k, v)
elif value is None: # incremental_state elif value is None: # incremental_state
k = layers.fill_constant_batch_size_like( fill_shape = [-1, self.num_heads, 0, self.head_dim]
input=key, fill_shape[0] = paddle.shape(key)[0].item()
shape=[-1, self.num_heads, 0, self.head_dim], k = paddle.full(shape=fill_shape, fill_value=0, dtype=key.dtype)
dtype=key.dtype, v = paddle.full(shape=fill_shape, fill_value=0, dtype=key.dtype)
value=0,
)
v = layers.fill_constant_batch_size_like(
input=key,
shape=[-1, self.num_heads, 0, self.head_dim],
dtype=key.dtype,
value=0,
)
return self.Cache(k, v) return self.Cache(k, v)
else: else:
# incremental_state with initial value, mainly for usage like UniLM # incremental_state with initial value, mainly for usage like UniLM
......
...@@ -68,17 +68,17 @@ def get_acc(cos_q_nt, cos_q_pt, batch_size): ...@@ -68,17 +68,17 @@ def get_acc(cos_q_nt, cos_q_pt, batch_size):
def get_loss(cos_q_pt, cos_q_nt): def get_loss(cos_q_pt, cos_q_nt):
fill_shape = [-1, 1]
fill_shape[0] = paddle.shape(cos_q_pt)[0].item()
loss_op1 = paddle.subtract( loss_op1 = paddle.subtract(
fluid.layers.fill_constant_batch_size_like( paddle.full(shape=fill_shape, fill_value=margin, dtype='float32'),
input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32'
),
cos_q_pt, cos_q_pt,
) )
loss_op2 = paddle.add(loss_op1, cos_q_nt) loss_op2 = paddle.add(loss_op1, cos_q_nt)
fill_shape[0] = paddle.shape(cos_q_pt)[0].item()
loss_op3 = paddle.maximum( loss_op3 = paddle.maximum(
fluid.layers.fill_constant_batch_size_like( paddle.full(shape=fill_shape, fill_value=0.0, dtype='float32'),
input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32'
),
loss_op2, loss_op2,
) )
avg_cost = paddle.mean(loss_op3) avg_cost = paddle.mean(loss_op3)
......
...@@ -23,7 +23,6 @@ from paddle.distributed.auto_parallel.static.dist_context import ( ...@@ -23,7 +23,6 @@ from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext, DistributedContext,
) )
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list from paddle.nn.layer.transformer import _convert_param_attr_to_list
paddle.enable_static() paddle.enable_static()
...@@ -172,18 +171,10 @@ class MultiHeadAttention(nn.Layer): ...@@ -172,18 +171,10 @@ class MultiHeadAttention(nn.Layer):
k, v = self.compute_kv(key, value) k, v = self.compute_kv(key, value)
return self.StaticCache(k, v) return self.StaticCache(k, v)
elif value is None: # incremental_state elif value is None: # incremental_state
k = layers.fill_constant_batch_size_like( fill_shape = [-1, self.num_heads, 0, self.head_dim]
input=key, fill_shape[0] = paddle.shape(key)[0].item()
shape=[-1, self.num_heads, 0, self.head_dim], k = paddle.full(shape=fill_shape, fill_value=0, dtype=key.dtype)
dtype=key.dtype, v = paddle.full(shape=fill_shape, fill_value=0, dtype=key.dtype)
value=0,
)
v = layers.fill_constant_batch_size_like(
input=key,
shape=[-1, self.num_heads, 0, self.head_dim],
dtype=key.dtype,
value=0,
)
return self.Cache(k, v) return self.Cache(k, v)
else: else:
# incremental_state with initial value, mainly for usage like UniLM # incremental_state with initial value, mainly for usage like UniLM
......
...@@ -28,7 +28,6 @@ from paddle.distributed.auto_parallel.static.process_group import ( ...@@ -28,7 +28,6 @@ from paddle.distributed.auto_parallel.static.process_group import (
) )
from paddle.distributed.auto_parallel.static.utils import _get_comm_group from paddle.distributed.auto_parallel.static.utils import _get_comm_group
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list from paddle.nn.layer.transformer import _convert_param_attr_to_list
paddle.enable_static() paddle.enable_static()
...@@ -218,18 +217,10 @@ class MultiHeadAttention(nn.Layer): ...@@ -218,18 +217,10 @@ class MultiHeadAttention(nn.Layer):
k, v = self.compute_kv(key, value) k, v = self.compute_kv(key, value)
return self.StaticCache(k, v) return self.StaticCache(k, v)
elif value is None: # incremental_state elif value is None: # incremental_state
k = layers.fill_constant_batch_size_like( fill_shape = [-1, self.num_heads, 0, self.head_dim]
input=key, fill_shape[0] = paddle.shape(key)[0].item()
shape=[-1, self.num_heads, 0, self.head_dim], k = paddle.full(shape=fill_shape, fill_value=0, dtype=key.dtype)
dtype=key.dtype, v = paddle.full(shape=fill_shape, fill_value=0, dtype=key.dtype)
value=0,
)
v = layers.fill_constant_batch_size_like(
input=key,
shape=[-1, self.num_heads, 0, self.head_dim],
dtype=key.dtype,
value=0,
)
return self.Cache(k, v) return self.Cache(k, v)
else: else:
# incremental_state with initial value, mainly for usage like UniLM # incremental_state with initial value, mainly for usage like UniLM
......
...@@ -49,17 +49,19 @@ class TestPSMinimize(unittest.TestCase): ...@@ -49,17 +49,19 @@ class TestPSMinimize(unittest.TestCase):
return acc return acc
def get_loss(cos_q_pt, cos_q_nt): def get_loss(cos_q_pt, cos_q_nt):
fill_shape = [-1, 1]
fill_shape[0] = paddle.shape(cos_q_pt)[0].item()
loss_op1 = paddle.subtract( loss_op1 = paddle.subtract(
fluid.layers.fill_constant_batch_size_like( paddle.full(
input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32' shape=fill_shape, fill_value=margin, dtype='float32'
), ),
cos_q_pt, cos_q_pt,
) )
loss_op2 = paddle.add(loss_op1, cos_q_nt) loss_op2 = paddle.add(loss_op1, cos_q_nt)
fill_shape = [-1, 1]
fill_shape[0] = paddle.shape(loss_op2)[0].item()
loss_op3 = paddle.maximum( loss_op3 = paddle.maximum(
fluid.layers.fill_constant_batch_size_like( paddle.full(shape=fill_shape, fill_value=0.0, dtype='float32'),
input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32'
),
loss_op2, loss_op2,
) )
avg_cost = paddle.mean(loss_op3) avg_cost = paddle.mean(loss_op3)
......
...@@ -49,17 +49,19 @@ class TestPSPassWithBow(unittest.TestCase): ...@@ -49,17 +49,19 @@ class TestPSPassWithBow(unittest.TestCase):
return acc return acc
def get_loss(cos_q_pt, cos_q_nt): def get_loss(cos_q_pt, cos_q_nt):
fill_shape = [-1, 1]
fill_shape[0] = paddle.shape(cos_q_pt)[0].item()
loss_op1 = paddle.subtract( loss_op1 = paddle.subtract(
fluid.layers.fill_constant_batch_size_like( paddle.full(
input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32' shape=fill_shape, fill_value=margin, dtype='float32'
), ),
cos_q_pt, cos_q_pt,
) )
loss_op2 = paddle.add(loss_op1, cos_q_nt) loss_op2 = paddle.add(loss_op1, cos_q_nt)
fill_shape = [-1, 1]
fill_shape[0] = paddle.shape(loss_op2)[0].item()
loss_op3 = paddle.maximum( loss_op3 = paddle.maximum(
fluid.layers.fill_constant_batch_size_like( paddle.full(shape=fill_shape, fill_value=0.0, dtype='float32'),
input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32'
),
loss_op2, loss_op2,
) )
avg_cost = paddle.mean(loss_op3) avg_cost = paddle.mean(loss_op3)
......
...@@ -49,17 +49,19 @@ class TestPSPassWithBow(unittest.TestCase): ...@@ -49,17 +49,19 @@ class TestPSPassWithBow(unittest.TestCase):
return acc return acc
def get_loss(cos_q_pt, cos_q_nt): def get_loss(cos_q_pt, cos_q_nt):
fill_shape = [-1, 1]
fill_shape[0] = paddle.shape(cos_q_pt)[0].item()
loss_op1 = paddle.subtract( loss_op1 = paddle.subtract(
fluid.layers.fill_constant_batch_size_like( paddle.full(
input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32' shape=fill_shape, fill_value=margin, dtype='float32'
), ),
cos_q_pt, cos_q_pt,
) )
loss_op2 = paddle.add(loss_op1, cos_q_nt) loss_op2 = paddle.add(loss_op1, cos_q_nt)
fill_shape = [-1, 1]
fill_shape[0] = paddle.shape(loss_op2)[0].item()
loss_op3 = paddle.maximum( loss_op3 = paddle.maximum(
fluid.layers.fill_constant_batch_size_like( paddle.full(shape=fill_shape, fill_value=0.0, dtype='float32'),
input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32'
),
loss_op2, loss_op2,
) )
avg_cost = paddle.mean(loss_op3) avg_cost = paddle.mean(loss_op3)
......
...@@ -52,17 +52,19 @@ class TestPSPassWithBow(unittest.TestCase): ...@@ -52,17 +52,19 @@ class TestPSPassWithBow(unittest.TestCase):
return acc return acc
def get_loss(cos_q_pt, cos_q_nt): def get_loss(cos_q_pt, cos_q_nt):
fill_shape = [-1, 1]
fill_shape[0] = paddle.shape(cos_q_pt)[0].item()
loss_op1 = paddle.subtract( loss_op1 = paddle.subtract(
fluid.layers.fill_constant_batch_size_like( paddle.full(
input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32' shape=fill_shape, fill_value=margin, dtype='float32'
), ),
cos_q_pt, cos_q_pt,
) )
loss_op2 = paddle.add(loss_op1, cos_q_nt) loss_op2 = paddle.add(loss_op1, cos_q_nt)
fill_shape = [-1, 1]
fill_shape[0] = paddle.shape(loss_op2)[0].item()
loss_op3 = paddle.maximum( loss_op3 = paddle.maximum(
fluid.layers.fill_constant_batch_size_like( paddle.full(shape=fill_shape, fill_value=0.0, dtype='float32'),
input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32'
),
loss_op2, loss_op2,
) )
avg_cost = paddle.mean(loss_op3) avg_cost = paddle.mean(loss_op3)
......
...@@ -53,17 +53,19 @@ class TestPSPassWithBow(unittest.TestCase): ...@@ -53,17 +53,19 @@ class TestPSPassWithBow(unittest.TestCase):
return acc return acc
def get_loss(cos_q_pt, cos_q_nt): def get_loss(cos_q_pt, cos_q_nt):
fill_shape = [-1, 1]
fill_shape[0] = paddle.shape(cos_q_pt)[0].item()
loss_op1 = paddle.subtract( loss_op1 = paddle.subtract(
fluid.layers.fill_constant_batch_size_like( paddle.full(
input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32' shape=fill_shape, fill_value=margin, dtype='float32'
), ),
cos_q_pt, cos_q_pt,
) )
loss_op2 = paddle.add(loss_op1, cos_q_nt) loss_op2 = paddle.add(loss_op1, cos_q_nt)
fill_shape = [-1, 1]
fill_shape[0] = paddle.shape(loss_op2)[0].item()
loss_op3 = paddle.maximum( loss_op3 = paddle.maximum(
fluid.layers.fill_constant_batch_size_like( paddle.full(shape=fill_shape, fill_value=0.0, dtype='float32'),
input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32'
),
loss_op2, loss_op2,
) )
avg_cost = paddle.mean(loss_op3) avg_cost = paddle.mean(loss_op3)
......
...@@ -52,17 +52,19 @@ class TestPSPassWithBow(unittest.TestCase): ...@@ -52,17 +52,19 @@ class TestPSPassWithBow(unittest.TestCase):
return acc return acc
def get_loss(cos_q_pt, cos_q_nt): def get_loss(cos_q_pt, cos_q_nt):
fill_shape = [-1, 1]
fill_shape[0] = paddle.shape(cos_q_pt)[0].item()
loss_op1 = paddle.subtract( loss_op1 = paddle.subtract(
fluid.layers.fill_constant_batch_size_like( paddle.full(
input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32' shape=fill_shape, fill_value=margin, dtype='float32'
), ),
cos_q_pt, cos_q_pt,
) )
loss_op2 = paddle.add(loss_op1, cos_q_nt) loss_op2 = paddle.add(loss_op1, cos_q_nt)
fill_shape = [-1, 1]
fill_shape[0] = paddle.shape(loss_op2)[0].item()
loss_op3 = paddle.maximum( loss_op3 = paddle.maximum(
fluid.layers.fill_constant_batch_size_like( paddle.full(shape=fill_shape, fill_value=0.0, dtype='float32'),
input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32'
),
loss_op2, loss_op2,
) )
avg_cost = paddle.mean(loss_op3) avg_cost = paddle.mean(loss_op3)
......
...@@ -49,17 +49,19 @@ class TestPSPassWithBow(unittest.TestCase): ...@@ -49,17 +49,19 @@ class TestPSPassWithBow(unittest.TestCase):
return acc return acc
def get_loss(cos_q_pt, cos_q_nt): def get_loss(cos_q_pt, cos_q_nt):
fill_shape = [-1, 1]
fill_shape[0] = paddle.shape(cos_q_pt)[0].item()
loss_op1 = paddle.subtract( loss_op1 = paddle.subtract(
fluid.layers.fill_constant_batch_size_like( paddle.full(
input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32' shape=fill_shape, fill_value=margin, dtype='float32'
), ),
cos_q_pt, cos_q_pt,
) )
loss_op2 = paddle.add(loss_op1, cos_q_nt) loss_op2 = paddle.add(loss_op1, cos_q_nt)
fill_shape = [-1, 1]
fill_shape[0] = paddle.shape(loss_op2)[0].item()
loss_op3 = paddle.maximum( loss_op3 = paddle.maximum(
fluid.layers.fill_constant_batch_size_like( paddle.full(shape=fill_shape, fill_value=0.0, dtype='float32'),
input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32'
),
loss_op2, loss_op2,
) )
avg_cost = paddle.mean(loss_op3) avg_cost = paddle.mean(loss_op3)
......
...@@ -49,17 +49,19 @@ class TestPSPassWithBow(unittest.TestCase): ...@@ -49,17 +49,19 @@ class TestPSPassWithBow(unittest.TestCase):
return acc return acc
def get_loss(cos_q_pt, cos_q_nt): def get_loss(cos_q_pt, cos_q_nt):
fill_shape = [-1, 1]
fill_shape[0] = paddle.shape(cos_q_pt)[0].item()
loss_op1 = paddle.subtract( loss_op1 = paddle.subtract(
fluid.layers.fill_constant_batch_size_like( paddle.full(
input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32' shape=fill_shape, fill_value=margin, dtype='float32'
), ),
cos_q_pt, cos_q_pt,
) )
loss_op2 = paddle.add(loss_op1, cos_q_nt) loss_op2 = paddle.add(loss_op1, cos_q_nt)
fill_shape = [-1, 1]
fill_shape[0] = paddle.shape(loss_op2)[0].item()
loss_op3 = paddle.maximum( loss_op3 = paddle.maximum(
fluid.layers.fill_constant_batch_size_like( paddle.full(shape=fill_shape, fill_value=0.0, dtype='float32'),
input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32'
),
loss_op2, loss_op2,
) )
avg_cost = paddle.mean(loss_op3) avg_cost = paddle.mean(loss_op3)
......
...@@ -49,17 +49,19 @@ class TestPSPassWithBow(unittest.TestCase): ...@@ -49,17 +49,19 @@ class TestPSPassWithBow(unittest.TestCase):
return acc return acc
def get_loss(cos_q_pt, cos_q_nt): def get_loss(cos_q_pt, cos_q_nt):
fill_shape = [-1, 1]
fill_shape[0] = paddle.shape(cos_q_pt)[0].item()
loss_op1 = paddle.subtract( loss_op1 = paddle.subtract(
fluid.layers.fill_constant_batch_size_like( paddle.full(
input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32' shape=fill_shape, fill_value=margin, dtype='float32'
), ),
cos_q_pt, cos_q_pt,
) )
loss_op2 = paddle.add(loss_op1, cos_q_nt) loss_op2 = paddle.add(loss_op1, cos_q_nt)
fill_shape = [-1, 1]
fill_shape[0] = paddle.shape(loss_op2)[0].item()
loss_op3 = paddle.maximum( loss_op3 = paddle.maximum(
fluid.layers.fill_constant_batch_size_like( paddle.full(shape=fill_shape, fill_value=0.0, dtype='float32'),
input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32'
),
loss_op2, loss_op2,
) )
avg_cost = paddle.mean(loss_op3) avg_cost = paddle.mean(loss_op3)
......
...@@ -49,17 +49,19 @@ class TestPSPassWithBow(unittest.TestCase): ...@@ -49,17 +49,19 @@ class TestPSPassWithBow(unittest.TestCase):
return acc return acc
def get_loss(cos_q_pt, cos_q_nt): def get_loss(cos_q_pt, cos_q_nt):
fill_shape = [-1, 1]
fill_shape[0] = paddle.shape(cos_q_pt)[0].item()
loss_op1 = paddle.subtract( loss_op1 = paddle.subtract(
fluid.layers.fill_constant_batch_size_like( paddle.full(
input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32' shape=fill_shape, fill_value=margin, dtype='float32'
), ),
cos_q_pt, cos_q_pt,
) )
loss_op2 = paddle.add(loss_op1, cos_q_nt) loss_op2 = paddle.add(loss_op1, cos_q_nt)
fill_shape = [-1, 1]
fill_shape[0] = paddle.shape(loss_op2)[0].item()
loss_op3 = paddle.maximum( loss_op3 = paddle.maximum(
fluid.layers.fill_constant_batch_size_like( paddle.full(shape=fill_shape, fill_value=0.0, dtype='float32'),
input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32'
),
loss_op2, loss_op2,
) )
avg_cost = paddle.mean(loss_op3) avg_cost = paddle.mean(loss_op3)
......
...@@ -47,17 +47,19 @@ class TestSPMT(unittest.TestCase): ...@@ -47,17 +47,19 @@ class TestSPMT(unittest.TestCase):
return acc return acc
def get_loss(cos_q_pt, cos_q_nt): def get_loss(cos_q_pt, cos_q_nt):
fill_shape = [-1, 1]
fill_shape[0] = paddle.shape(cos_q_pt)[0].item()
loss_op1 = paddle.subtract( loss_op1 = paddle.subtract(
fluid.layers.fill_constant_batch_size_like( paddle.full(
input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32' shape=fill_shape, fill_value=margin, dtype='float32'
), ),
cos_q_pt, cos_q_pt,
) )
loss_op2 = paddle.add(loss_op1, cos_q_nt) loss_op2 = paddle.add(loss_op1, cos_q_nt)
fill_shape = [-1, 1]
fill_shape[0] = paddle.shape(loss_op2)[0].item()
loss_op3 = paddle.maximum( loss_op3 = paddle.maximum(
fluid.layers.fill_constant_batch_size_like( paddle.full(shape=fill_shape, fill_value=0.0, dtype='float32'),
input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32'
),
loss_op2, loss_op2,
) )
avg_cost = paddle.mean(loss_op3) avg_cost = paddle.mean(loss_op3)
......
...@@ -422,11 +422,15 @@ class TestFakeInit(TranspilerTest): ...@@ -422,11 +422,15 @@ class TestFakeInit(TranspilerTest):
neg_matmul_re = paddle.reshape(neg_matmul, shape=[-1, neg_num]) neg_matmul_re = paddle.reshape(neg_matmul, shape=[-1, neg_num])
neg_logits = paddle.add(neg_matmul_re, neg_emb_b_vec) neg_logits = paddle.add(neg_matmul_re, neg_emb_b_vec)
# nce loss # nce loss
label_ones = fluid.layers.fill_constant_batch_size_like( fill_shape = [-1, 1]
true_logits, shape=[-1, 1], value=1.0, dtype='float32' fill_shape[0] = paddle.shape(true_logits)[0].item()
) label_ones = paddle.full(
label_zeros = fluid.layers.fill_constant_batch_size_like( shape=fill_shape, fill_value=1.0, dtype='float32'
true_logits, shape=[-1, neg_num], value=0.0, dtype='float32' )
fill_shape = [-1, neg_num]
fill_shape[0] = paddle.shape(true_logits)[0].item()
label_zeros = paddle.full(
shape=fill_shape, fill_value=0.0, dtype='float32'
) )
true_xent = paddle.nn.functional.binary_cross_entropy_with_logits( true_xent = paddle.nn.functional.binary_cross_entropy_with_logits(
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle.fluid import core
from paddle.fluid.framework import convert_np_dtype_to_dtype_
paddle.enable_static()
def fill_constant_batch_size_like(
input,
shape,
value,
data_type,
input_dim_idx=0,
output_dim_idx=0,
force_cpu=False,
):
return paddle.fluid.layers.fill_constant_batch_size_like(
input, shape, data_type, value, input_dim_idx, output_dim_idx, force_cpu
)
class TestFillConstantBatchSizeLike1(OpTest):
# test basic
def setUp(self):
self.op_type = "fill_constant_batch_size_like"
self.python_api = fill_constant_batch_size_like
self.init_dtype()
self.init_data()
input = np.zeros(self.shape)
out = np.full_like(input, self.value, self.dtype)
self.inputs = {'Input': input}
self.outputs = {'Out': out}
self.attrs = {
'shape': self.shape,
'dtype': convert_np_dtype_to_dtype_(self.dtype),
'value': self.value,
'input_dim_idx': self.input_dim_idx,
'output_dim_idx': self.output_dim_idx,
'force_cpu': self.force_cpu,
}
def init_dtype(self):
self.dtype = np.float32
def init_data(self):
self.shape = [10, 10]
self.value = 100
self.input_dim_idx = 0
self.output_dim_idx = 0
self.force_cpu = False
def test_check_output(self):
self.check_output()
class TestFillConstantBatchSizeLikeFP16Op(TestFillConstantBatchSizeLike1):
def init_dtype(self):
self.dtype = np.float16
@unittest.skipIf(
not core.is_compiled_with_cuda() or not core.supports_bfloat16(),
"core is not compiled with CUDA or place do not support bfloat16",
)
class TestFillConstantBatchSizeLikeBF16Op(OpTest):
# test bf16
def setUp(self):
self.op_type = "fill_constant_batch_size_like"
self.python_api = fill_constant_batch_size_like
self.init_data()
input = np.zeros(self.shape).astype("float32")
input_bf16 = convert_float_to_uint16(input)
out = np.full_like(input, self.value, np.float32)
out_bf16 = convert_float_to_uint16(out)
self.inputs = {'Input': input_bf16}
self.outputs = {'Out': out_bf16}
self.attrs = {
'shape': self.shape,
'dtype': convert_np_dtype_to_dtype_(self.dtype),
'value': self.value,
'input_dim_idx': self.input_dim_idx,
'output_dim_idx': self.output_dim_idx,
'force_cpu': self.force_cpu,
}
def init_data(self):
self.shape = [10, 10]
self.dtype = np.uint16
self.value = 100
self.input_dim_idx = 0
self.output_dim_idx = 0
self.force_cpu = False
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
...@@ -2133,16 +2133,6 @@ class TestBook(LayerTest): ...@@ -2133,16 +2133,6 @@ class TestBook(LayerTest):
) )
return out return out
def test_fill_constant_batch_size_like(self):
with self.static_graph():
like = paddle.tensor.fill_constant(
shape=[1, 200], value=10, dtype='int64'
)
out = layers.fill_constant_batch_size_like(
input=like, shape=[2, 3300], value=1315454564656, dtype='int64'
)
return out
def test_shuffle_batch(self): def test_shuffle_batch(self):
# TODO(minqiyang): dygraph do not support lod now # TODO(minqiyang): dygraph do not support lod now
with self.static_graph(): with self.static_graph():
......