|
3 | 3 | # This source code is licensed under the MIT license found in the
|
4 | 4 | # LICENSE file in the root directory of this source tree.
|
5 | 5 |
|
| 6 | +import io |
6 | 7 | import tempfile
|
7 | 8 | import unittest
|
8 | 9 |
|
@@ -65,6 +66,51 @@ def assertMatch(ids, ref_ids):
|
65 | 66 | assertMatch(reload_ids, ref_ids2)
|
66 | 67 | assertMatch(finalized_ids, reload_ids)
|
67 | 68 |
|
def test_overwrite(self):
    """Load a dictionary file whose special symbols carry the overwrite flag.

    The first three entries are tagged with ``#fairseq:overwrite`` (as
    Camembert does for <unk>, <s> and </s>); loading must succeed and every
    entry of the file must be reachable at the index asserted below.
    """
    vocab_source = io.StringIO(
        "<unk> 999 #fairseq:overwrite\n"
        "<s> 999 #fairseq:overwrite\n"
        "</s> 999 #fairseq:overwrite\n"
        ", 999\n"
        "▁de 999\n"
    )
    vocab = Dictionary()
    vocab.add_from_file(vocab_source)

    # (symbol, expected index) pairs, checked in order.
    # NOTE(review): '<pad>' and 'foo' are not in the file; 'foo' presumably
    # resolves to the dictionary's unknown-symbol index -- confirm against
    # Dictionary.index.
    expected_indices = [
        ('<pad>', 1),
        ('foo', 3),
        ('<unk>', 4),
        ('<s>', 5),
        ('</s>', 6),
        (',', 7),
        ('▁de', 8),
    ]
    for symbol, position in expected_indices:
        self.assertEqual(vocab.index(symbol), position)
| 87 | + |
def test_no_overwrite(self):
    """Loading must fail when the file lacks the overwrite flag.

    The file lists <unk>, <s> and </s> without ``#fairseq:overwrite``;
    ``add_from_file`` is expected to raise a RuntimeError whose message
    matches 'Duplicate'.
    """
    # NOTE(review): presumably these three symbols already exist in a
    # freshly constructed Dictionary, which is what makes them duplicates.
    vocab_source = io.StringIO(
        "<unk> 999\n"
        "<s> 999\n"
        "</s> 999\n"
        ", 999\n"
        "▁de 999\n"
    )
    vocab = Dictionary()
    # Callable form: assertRaisesRegex invokes add_from_file(vocab_source).
    self.assertRaisesRegex(
        RuntimeError, 'Duplicate', vocab.add_from_file, vocab_source
    )
| 100 | + |
def test_space(self):
    """Check that a single space can be loaded as a dictionary symbol.

    Character-level models treat the space character as a regular symbol,
    so a line whose symbol is ' ' must be parsed as (symbol=' ', count=999)
    rather than being stripped away.
    """
    # The line for the space symbol is "<space><separator>999": two
    # consecutive spaces. It is assembled from an explicit variable so the
    # symbol cannot be lost if whitespace in the literal is collapsed,
    # which would leave an empty symbol and break the first assertion.
    space_symbol = " "
    dict_file = io.StringIO(
        space_symbol + " 999\n"
        "a 999\n"
        "b 999\n"
    )
    d = Dictionary()
    d.add_from_file(dict_file)

    # NOTE(review): indices start at 4, presumably because a fresh
    # Dictionary reserves 0-3 for its built-in special symbols -- confirm.
    self.assertEqual(d.index(space_symbol), 4)
    self.assertEqual(d.index('a'), 5)
    self.assertEqual(d.index('b'), 6)
| 113 | + |
68 | 114 |
|
69 | 115 | if __name__ == '__main__':
|
70 | 116 | unittest.main()
|
0 commit comments