...
 
Commits (6)
    https://gitcode.net/weixin_42428077/models/-/commit/bfb3e7e653ec97adc5ff6a56bae7c6d55a8bf350 Internal change 2023-07-05T10:27:10-07:00 Chaochao Yan allenyan@google.com PiperOrigin-RevId: 545715680 https://gitcode.net/weixin_42428077/models/-/commit/0b3b5c0656e0293e7ed10f99d374b770569c2bef Internal change 2023-07-05T13:24:29-07:00 A. Unique TensorFlower gardener@tensorflow.org PiperOrigin-RevId: 545768729 https://gitcode.net/weixin_42428077/models/-/commit/e7ed21934312ca749e13e288b7dfafcb6f4b5b22 Internal change 2023-07-05T14:55:38-07:00 Liangzhe Yuan lzyuan@google.com PiperOrigin-RevId: 545793613 https://gitcode.net/weixin_42428077/models/-/commit/bb974b664022f44ea91d98e253a1685536700cad Remove tf-text version requirement. 2023-07-05T20:29:05-07:00 Yuexin Wu crickwu@google.com PiperOrigin-RevId: 545856817 https://gitcode.net/weixin_42428077/models/-/commit/307271f1c8047e219d40ef9f37e6c02ee46d207a Internal change 2023-07-06T10:38:26-07:00 A. Unique TensorFlower gardener@tensorflow.org PiperOrigin-RevId: 546030270 https://gitcode.net/weixin_42428077/models/-/commit/abb7ed6a4628118042876c17ba48c559e0fb5025 Remove tf-text version requirement. 2023-07-06T13:47:13-07:00 Yuexin Wu crickwu@google.com PiperOrigin-RevId: 546081639
......@@ -28,7 +28,7 @@ import yaml
# values (e.g. floats, ints, etc.), and a lists within brackets.
_PARAM_RE = re.compile(
r"""
(?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
(?P<name>[a-zA-Z][\w\.]*)(?P<bracketed_index>\[?[0-9]*\]?) # variable name: "var" or "x" followed by optional index: "[0]" or "[23]"
\s*=\s*
((?P<val>\'(.*?)\' # single quote
|
......@@ -223,7 +223,7 @@ class ParamsDict(object):
"""Validate the parameters consistency based on the restrictions.
This method validates the internal consistency using the pre-defined list of
restrictions. A restriction is defined as a string which specfiies a binary
restrictions. A restriction is defined as a string which specifies a binary
operation. The supported binary operations are {'==', '!=', '<', '<=', '>',
'>='}. Note that the meaning of these operators are consistent with the
underlying Python immplementation. Users should make sure the define
......@@ -385,6 +385,8 @@ def nested_csv_str_to_json_str(csv_str):
if not csv_str:
return ''
array_param_map = collections.defaultdict(str)
max_index_map = collections.defaultdict(str)
formatted_entries = []
nested_map = collections.defaultdict(list)
pos = 0
......@@ -398,6 +400,27 @@ def nested_csv_str_to_json_str(csv_str):
m_dict = m.groupdict()
name = m_dict['name']
v = m_dict['val']
bracketed_index = m_dict['bracketed_index']
# If we reach the name of the array.
if bracketed_index and '.' not in name:
# Extract the array's index by removing '[' and ']'
index = int(bracketed_index[1:-1])
if '.' in v:
numeric_val = float(v)
else:
numeric_val = int(v)
# Add the value to the array.
if name not in array_param_map:
max_index_map[name] = index
array_param_map[name] = [None] * (index + 1)
array_param_map[name][index] = numeric_val
elif index < max_index_map[name]:
array_param_map[name][index] = numeric_val
else:
array_param_map[name] += [None] * (index - max_index_map[name])
array_param_map[name][index] = numeric_val
max_index_map[name] = index
continue
# If a GCS path (e.g. gs://...) is provided, wrap this in quotes
# as yaml.load would otherwise throw an exception
......@@ -407,7 +430,10 @@ def nested_csv_str_to_json_str(csv_str):
name_nested = name.split('.')
if len(name_nested) > 1:
grouping = name_nested[0]
value = '.'.join(name_nested[1:]) + '=' + v
if bracketed_index:
value = '.'.join(name_nested[1:]) + bracketed_index + '=' + v
else:
value = '.'.join(name_nested[1:]) + '=' + v
nested_map[grouping].append(value)
else:
formatted_entries.append('%s : %s' % (name, v))
......@@ -416,6 +442,13 @@ def nested_csv_str_to_json_str(csv_str):
value = ','.join(value)
value = nested_csv_str_to_json_str(value)
formatted_entries.append('%s : %s' % (grouping, value))
# Add array parameters and check that the array is fully initialized.
for name in array_param_map:
if any(v is None for v in array_param_map[name]):
raise ValueError('Did not pass all values of array: %s' % name)
formatted_entries.append('%s : %s' % (name, array_param_map[name]))
return '{' + ', '.join(formatted_entries) + '}'
......
......@@ -176,7 +176,7 @@ class ParamsDictTest(tf.test.TestCase):
# Valid rule.
params = params_dict.ParamsDict({'a': 1, 'c': {'a': 1}}, ['a == c.a'])
# Overridding violates the existing rule, raise error upon validate.
# Overriding violates the existing rule, raise error upon validate.
params.override({'a': 11})
with self.assertRaises(KeyError):
params.validate()
......@@ -393,6 +393,23 @@ class IOTest(tf.test.TestCase):
converted_dict = yaml.load(converted_csv_str)
self.assertDictEqual(converted_dict, expected_output)
def test_int_array_param_nested_csv_str_to_json_str(self):
# Integer array entries given out of order ([2],[0],[1]) must be
# reassembled into an index-ordered JSON list under the nested key.
csv_str = 'a.b[2]=3,a.b[0]=1,a.b[1]=2'
json_str = '{a : {b : [1, 2, 3]}}'
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
self.assertEqual(converted_csv_str, json_str)
def test_float_array_param_nested_csv_str_to_json_str(self):
# Float array entries (values containing '.') are parsed as floats and
# ordered by their bracketed index, not by their position in the CSV.
csv_str = 'a.b[1]=3.45,a.b[2]=1.32,a.b[0]=2.232'
json_str = '{a : {b : [2.232, 3.45, 1.32]}}'
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
self.assertEqual(converted_csv_str, json_str)
def test_incomplete_array_param_nested_csv_str_to_json_str(self):
# A gap in the indices (here [1] is missing) leaves a None hole in the
# array, which the converter must reject with ValueError.
csv_str = 'a.b[0]=1,a.b[2]=2'
self.assertRaises(ValueError, params_dict.nested_csv_str_to_json_str,
csv_str)
def test_csv_str_load_supported_datatypes(self):
csv_str = 'a=1,b=2.,c=[1,2,3],d=\'hello, there\',e=\"Hi.\"'
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
......
......@@ -20,8 +20,8 @@ import sys
from setuptools import find_packages
from setuptools import setup
version = '2.12.0'
tf_version = '2.12.0' # Major version.
version = '2.13.0'
tf_version = '2.13.0' # Major version.
project_name = 'tf-models-official'
......
......@@ -92,13 +92,17 @@ class _ViTAdamW(nlp_optimization.AdamWeightDecay):
and self._vars_substr is not None
and self._layers_idx is not None
):
is_decayed = False
for var_substr, idx in zip(self._vars_substr, self._layers_idx):
if var_substr in var.name:
decay_factor = self._layer_decay ** (self._max_idx - idx)
lr_t = lr_t * decay_factor
is_decayed = True
logging.debug(
'Applying layer-wise lr decay: %s: %f', var.name, decay_factor)
break
if not is_decayed:
logging.debug('Ignore layer-wise lr decay: %s', var.name)
decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]):
var_device, var_dtype = var.device, var.dtype.base_dtype
......@@ -155,13 +159,17 @@ class _ViTAdamW(nlp_optimization.AdamWeightDecay):
and self._vars_substr is not None
and self._layers_idx is not None
):
is_decayed = False
for var_substr, idx in zip(self._vars_substr, self._layers_idx):
if var_substr in var.name:
decay_factor = self._layer_decay ** (self._max_idx - idx)
lr_t = lr_t * decay_factor
is_decayed = True
logging.debug(
'Applying layer-wise lr decay: %s: %f', var.name, decay_factor)
break
if not is_decayed:
logging.debug('Ignore layer-wise lr decay: %s', var.name)
decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]):
var_device, var_dtype = var.device, var.dtype.base_dtype
......
......@@ -545,7 +545,7 @@ class MaskRCNNTask(base_task.Task):
state.update(visualization_utils.update_detection_state(step_outputs))
# TODO(allenyan): Mapping `detection_masks` (w.r.t. the `gt_boxes`) back
# to full masks (w.r.t. the image). Disable mask visualization fow now.
state.pop('detection_masks')
state.pop('detection_masks', None)
if not state:
# Create an arbitrary state to indicate it's not the first step in the
......