Skip to content

Commit 937535d

Browse files
Myle Ottfacebook-github-bot
authored andcommitted
Allow dictionaries to overwrite entries with #fairseq:overwrite comment (#1073)
Summary: [This commit](dd1298e) made it so that duplicate entries in a dictionary are ignored. Unfortunately the Camembert model depends on overwriting `<unk>`, `<s>` and `</s>`. The proposed solution here is to allow the dictionary to have entries like: ``` <unk> 999 #fairseq:overwrite <s> 999 #fairseq:overwrite </s> 999 #fairseq:overwrite , 999 ▁de 999 . 999 (...) ``` These will preserve the old overwriting behavior. Thus we can release a new `camembert.v0.tar.gz` with a dictionary like above and it works. Pull Request resolved: fairinternal/fairseq-py#1073 Reviewed By: kahne Differential Revision: D20284569 Pulled By: myleott fbshipit-source-id: bf78fbff13c94bf8a6485cbdda62305ddc30c056
1 parent 3dd221c commit 937535d

File tree

2 files changed

+70
-8
lines changed

2 files changed

+70
-8
lines changed

fairseq/data/dictionary.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,9 @@ def unk_string(self, escape=False):
9191
else:
9292
return self.unk_word
9393

94-
def add_symbol(self, word, n=1):
94+
def add_symbol(self, word, n=1, overwrite=False):
9595
"""Adds a word to the dictionary"""
96-
if word in self.indices:
96+
if word in self.indices and not overwrite:
9797
idx = self.indices[word]
9898
self.count[idx] = self.count[idx] + n
9999
return idx
@@ -215,15 +215,31 @@ def add_from_file(self, f):
215215

216216
lines = f.readlines()
217217
indices_start_line = self._load_meta(lines)
218+
218219
for line in lines[indices_start_line:]:
219-
idx = line.rfind(" ")
220-
if idx == -1:
220+
try:
221+
line, field = line.rstrip().rsplit(" ", 1)
222+
if field == "#fairseq:overwrite":
223+
overwrite = True
224+
line, field = line.rsplit(" ", 1)
225+
else:
226+
overwrite = False
227+
count = int(field)
228+
word = line
229+
if word in self and not overwrite:
230+
raise RuntimeError(
231+
"Duplicate word found when loading Dictionary: '{}'. "
232+
"Duplicate words can overwrite earlier ones by adding the "
233+
"#fairseq:overwrite flag at the end of the corresponding row "
234+
"in the dictionary file. If using the Camembert model, please "
235+
"download an updated copy of the model file."
236+
.format(word)
237+
)
238+
self.add_symbol(word, n=count, overwrite=overwrite)
239+
except ValueError:
221240
raise ValueError(
222-
"Incorrect dictionary format, expected '<token> <cnt>'"
241+
"Incorrect dictionary format, expected '<token> <cnt> [flags]'"
223242
)
224-
word = line[:idx]
225-
count = int(line[idx + 1 :])
226-
self.add_symbol(word, n=count)
227243

228244
def _save(self, f, kv_iterator):
229245
if isinstance(f, str):

tests/test_dictionary.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import io
67
import tempfile
78
import unittest
89

@@ -65,6 +66,51 @@ def assertMatch(ids, ref_ids):
6566
assertMatch(reload_ids, ref_ids2)
6667
assertMatch(finalized_ids, reload_ids)
6768

69+
def test_overwrite(self):
70+
# for example, Camembert overwrites <unk>, <s> and </s>
71+
dict_file = io.StringIO(
72+
"<unk> 999 #fairseq:overwrite\n"
73+
"<s> 999 #fairseq:overwrite\n"
74+
"</s> 999 #fairseq:overwrite\n"
75+
", 999\n"
76+
"▁de 999\n"
77+
)
78+
d = Dictionary()
79+
d.add_from_file(dict_file)
80+
self.assertEqual(d.index('<pad>'), 1)
81+
self.assertEqual(d.index('foo'), 3)
82+
self.assertEqual(d.index('<unk>'), 4)
83+
self.assertEqual(d.index('<s>'), 5)
84+
self.assertEqual(d.index('</s>'), 6)
85+
self.assertEqual(d.index(','), 7)
86+
self.assertEqual(d.index('▁de'), 8)
87+
88+
def test_no_overwrite(self):
89+
# for example, Camembert overwrites <unk>, <s> and </s>
90+
dict_file = io.StringIO(
91+
"<unk> 999\n"
92+
"<s> 999\n"
93+
"</s> 999\n"
94+
", 999\n"
95+
"▁de 999\n"
96+
)
97+
d = Dictionary()
98+
with self.assertRaisesRegex(RuntimeError, 'Duplicate'):
99+
d.add_from_file(dict_file)
100+
101+
def test_space(self):
102+
# for example, character models treat space as a symbol
103+
dict_file = io.StringIO(
104+
" 999\n"
105+
"a 999\n"
106+
"b 999\n"
107+
)
108+
d = Dictionary()
109+
d.add_from_file(dict_file)
110+
self.assertEqual(d.index(' '), 4)
111+
self.assertEqual(d.index('a'), 5)
112+
self.assertEqual(d.index('b'), 6)
113+
68114

69115
if __name__ == '__main__':
70116
unittest.main()

0 commit comments

Comments
 (0)