算法-字符串

字符串排序|单词查找树|子字符串查找|正则表达式|数据压缩

字符串排序

字母表

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
/**
 * An alphabet: a one-to-one mapping between a set of R characters and the
 * indices {@code 0} through {@code R-1}, together with a handful of
 * predefined alphabets.
 */
public class Alphabet
{
    // number of distinct char values: 0 through Character.MAX_VALUE inclusive
    private static final int CHAR_RANGE = Character.MAX_VALUE + 1;

    /**
     * The binary alphabet { 0, 1 }.
     */
    public static final Alphabet BINARY = new Alphabet("01");

    /**
     * The octal alphabet { 0, 1, 2, 3, 4, 5, 6, 7 }.
     */
    public static final Alphabet OCTAL = new Alphabet("01234567");

    /**
     * The decimal alphabet { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }.
     */
    public static final Alphabet DECIMAL = new Alphabet("0123456789");

    /**
     * The hexadecimal alphabet { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, A, B, C, D, E, F }.
     */
    public static final Alphabet HEXADECIMAL = new Alphabet("0123456789ABCDEF");

    /**
     * The DNA alphabet { A, C, G, T }.
     */
    public static final Alphabet DNA = new Alphabet("ACGT");

    /**
     * The lowercase alphabet { a, b, c, ..., z }.
     */
    public static final Alphabet LOWERCASE = new Alphabet("abcdefghijklmnopqrstuvwxyz");

    /**
     * The uppercase alphabet { A, B, C, ..., Z }.
     */
    public static final Alphabet UPPERCASE = new Alphabet("ABCDEFGHIJKLMNOPQRSTUVWXYZ");

    /**
     * The protein alphabet { A, C, D, E, F, G, H, I, K, L, M, N, P, Q, R, S, T, V, W, Y }.
     */
    public static final Alphabet PROTEIN = new Alphabet("ACDEFGHIKLMNPQRSTVWY");

    /**
     * The base-64 alphabet (64 characters).
     */
    public static final Alphabet BASE64 = new Alphabet("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/");

    /**
     * The ASCII alphabet (0-127).
     */
    public static final Alphabet ASCII = new Alphabet(128);

    /**
     * The extended ASCII alphabet (0-255).
     */
    public static final Alphabet EXTENDED_ASCII = new Alphabet(256);

    /**
     * The Unicode 16 alphabet (0-65,535).
     */
    public static final Alphabet UNICODE16 = new Alphabet(65536);


    private final char[] alphabet; // the characters in the alphabet
    private final int[] inverse;   // inverse[c] = index of character c, or -1 if absent
    private final int R;           // the radix of the alphabet

    /**
     * Initializes a new alphabet from the given set of characters.
     *
     * @param alpha the set of characters
     * @throws IllegalArgumentException if {@code alpha} contains a repeated character
     */
    public Alphabet(String alpha)
    {
        alphabet = alpha.toCharArray();
        R = alpha.length();

        // one slot per possible char value, so that even U+FFFF is a valid index
        inverse = new int[CHAR_RANGE];
        for (int i = 0; i < inverse.length; i++)
            inverse[i] = -1;

        // can't use char for the loop variable since R can be as big as 65,536
        for (int i = 0; i < R; i++)
        {
            char c = alphabet[i];
            // a slot that is already filled means this character appeared earlier
            if (inverse[c] != -1)
                throw new IllegalArgumentException("Illegal alphabet: repeated character = '" + c + "'");
            inverse[c] = i;
        }
    }

    /**
     * Initializes a new alphabet using characters 0 through R-1.
     *
     * @param radix the number of characters in the alphabet (the radix R)
     */
    private Alphabet(int radix)
    {
        this.R = radix;
        alphabet = new char[R];
        inverse = new int[R];

        // can't use char for the loop variable since R can be as big as 65,536
        for (int i = 0; i < R; i++)
        {
            alphabet[i] = (char) i;
            inverse[i] = i;
        }
    }

    /**
     * Initializes a new alphabet using characters 0 through 255.
     */
    public Alphabet()
    {
        this(256);
    }

    /**
     * Returns true if the argument is a character in this alphabet.
     *
     * @param c the character
     * @return {@code true} if {@code c} is a character in this alphabet;
     * {@code false} otherwise
     */
    public boolean contains(char c)
    {
        // radix-built alphabets only have R slots, so guard the index first
        return c < inverse.length && inverse[c] != -1;
    }

    /**
     * Returns the number of characters in this alphabet (the radix).
     *
     * @return the number of characters in this alphabet
     * @deprecated Replaced by {@link #radix()}.
     */
    @Deprecated
    public int R()
    {
        return R;
    }

    /**
     * Returns the number of characters in this alphabet (the radix).
     *
     * @return the number of characters in this alphabet
     */
    public int radix()
    {
        return R;
    }

    /**
     * Returns the binary logarithm of the number of characters in this alphabet.
     *
     * @return the binary logarithm (rounded up) of the number of characters in this alphabet
     */
    public int lgR()
    {
        int lgR = 0;
        // count the bits needed to represent the largest index, R - 1
        for (int t = R - 1; t >= 1; t /= 2)
            lgR++;
        return lgR;
    }

    /**
     * Returns the index corresponding to the argument character.
     *
     * @param c the character
     * @return the index corresponding to the character {@code c}
     * @throws IllegalArgumentException unless {@code c} is a character in this alphabet
     */
    public int toIndex(char c)
    {
        if (c >= inverse.length || inverse[c] == -1)
        {
            throw new IllegalArgumentException("Character " + c + " not in alphabet");
        }
        return inverse[c];
    }

    /**
     * Returns the indices corresponding to the argument characters.
     *
     * @param s the characters
     * @return the indices corresponding to the characters {@code s}
     * @throws IllegalArgumentException unless every character in {@code s}
     * is a character in this alphabet
     */
    public int[] toIndices(String s)
    {
        char[] source = s.toCharArray();
        int[] target = new int[s.length()];
        for (int i = 0; i < source.length; i++)
            target[i] = toIndex(source[i]);
        return target;
    }

    /**
     * Returns the character corresponding to the argument index.
     *
     * @param index the index
     * @return the character corresponding to the index {@code index}
     * @throws IllegalArgumentException unless {@code 0 <= index < R}
     */
    public char toChar(int index)
    {
        if (index < 0 || index >= R)
        {
            throw new IllegalArgumentException("index must be between 0 and " + R + ": " + index);
        }
        return alphabet[index];
    }

    /**
     * Returns the characters corresponding to the argument indices.
     *
     * @param indices the indices
     * @return the characters corresponding to the indices {@code indices}
     * @throws IllegalArgumentException unless {@code 0 <= indices[i] < R}
     * for every {@code i}
     */
    public String toChars(int[] indices)
    {
        StringBuilder s = new StringBuilder(indices.length);
        for (int i = 0; i < indices.length; i++)
            s.append(toChar(indices[i]));
        return s.toString();
    }

    /**
     * Unit tests the {@code Alphabet} data type by round-tripping three
     * strings through {@code toIndices} and {@code toChars}.
     *
     * @param args the command-line arguments
     */
    public static void main(String[] args)
    {
        int[] encoded1 = Alphabet.BASE64.toIndices("NowIsTheTimeForAllGoodMen");
        String decoded1 = Alphabet.BASE64.toChars(encoded1);
        StdOut.println(decoded1);

        int[] encoded2 = Alphabet.DNA.toIndices("AACGAACGGTTTACCCCG");
        String decoded2 = Alphabet.DNA.toChars(encoded2);
        StdOut.println(decoded2);

        int[] encoded3 = Alphabet.DECIMAL.toIndices("01234567890123456789");
        String decoded3 = Alphabet.DECIMAL.toChars(encoded3);
        StdOut.println(decoded3);
    }
}

低位优先的字符串排序

各字符串长度需相等

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
/**
 * LSD (least-significant-digit first) radix sort for arrays of strings that
 * all have at least {@code w} characters, using key-indexed counting on each
 * character position from right to left.
 */
public class LSD
{
    private static final int R = 256; // extended ASCII alphabet size

    // do not instantiate
    private LSD() { }

    /**
     * Rearranges the array of strings in ascending order of their leading
     * {@code w} characters. Each pass is a stable key-indexed counting sort,
     * which is what makes the right-to-left order of the passes correct.
     * Every character examined must be less than 256, because the count
     * array has only {@code R + 1} entries.
     *
     * @param a the array to be sorted
     * @param w the number of characters per string to sort on
     * @throws IllegalArgumentException if some string has fewer than {@code w} characters
     */
    public static void sort(String[] a, int w)
    {
        int n = a.length;

        // reject short strings with a descriptive error instead of a
        // StringIndexOutOfBoundsException from inside the counting loop
        for (int i = 0; i < n; i++)
            if (a[i].length() < w)
                throw new IllegalArgumentException("string at index " + i + " has fewer than " + w + " characters");

        String[] aux = new String[n];

        for (int d = w - 1; d >= 0; d--) // sort by key-indexed counting on dth character
        {
            // compute frequency counts (offset by one so the cumulates become start indices)
            int[] count = new int[R + 1];
            for (int i = 0; i < n; i++)
                count[a[i].charAt(d) + 1]++;

            // compute cumulates
            for (int r = 0; r < R; r++)
                count[r + 1] += count[r];

            // move data
            for (int i = 0; i < n; i++)
                aux[count[a[i].charAt(d)]++] = a[i];

            // copy back
            System.arraycopy(aux, 0, a, 0, n);
        }
    }

    /**
     * Reads in a sequence of fixed-length strings from standard input;
     * LSD radix sorts them;
     * and prints them to standard output in ascending order.
     *
     * @param args the command-line arguments
     */
    public static void main(String[] args)
    {
        String[] a = StdIn.readAllStrings();
        int n = a.length;

        // empty input: there is no first string to take the width from
        if (n == 0) return;

        // check that strings have fixed length
        int w = a[0].length();
        for (int i = 0; i < n; i++)
            assert a[i].length() == w : "Strings must have fixed length";

        // sort the strings
        sort(a, w);

        // print results
        for (int i = 0; i < n; i++)
            StdOut.println(a[i]);
    }
}

高位优先的字符串排序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
/**
 * MSD (most-significant-digit first) radix sort for arrays of extended ASCII
 * strings and for arrays of 32-bit integers, with a cutoff to insertion sort
 * for small subarrays.
 */
public class MSD
{
    private static final int BITS_PER_BYTE = 8;
    private static final int BITS_PER_INT = 32; // each Java int is 32 bits
    private static final int BYTES_PER_INT = BITS_PER_INT / BITS_PER_BYTE;
    private static final int R = 256; // extended ASCII alphabet size
    private static final int CUTOFF = 15; // cutoff to insertion sort

    // do not instantiate
    private MSD() { }

    /**
     * Rearranges the array of extended ASCII strings in ascending order.
     * Every character must be less than 256, because the count array has
     * only {@code R + 2} entries.
     *
     * @param a the array to be sorted
     */
    public static void sort(String[] a)
    {
        int n = a.length;
        String[] aux = new String[n];
        sort(a, 0, n - 1, 0, aux);
    }

    // return dth character of s, -1 if d = length of string
    private static int charAt(String s, int d)
    {
        assert d >= 0 && d <= s.length();
        if (d == s.length()) return -1;
        return s.charAt(d);
    }

    // sort from a[lo] to a[hi], starting at the dth character
    private static void sort(String[] a, int lo, int hi, int d, String[] aux)
    {

        // cutoff to insertion sort for small subarrays
        if (hi <= lo + CUTOFF)
        {
            insertion(a, lo, hi, d);
            return;
        }

        // compute frequency counts; the +2 offset leaves room for the
        // end-of-string sentinel -1 and turns the cumulates into start indices
        int[] count = new int[R + 2];
        for (int i = lo; i <= hi; i++)
        {
            int c = charAt(a[i], d);
            count[c + 2]++;
        }

        // transform counts to indices
        for (int r = 0; r < R + 1; r++)
            count[r + 1] += count[r];

        // distribute
        for (int i = lo; i <= hi; i++)
        {
            int c = charAt(a[i], d);
            aux[count[c + 1]++] = a[i];
        }

        // copy back
        for (int i = lo; i <= hi; i++)
            a[i] = aux[i - lo];


        // recursively sort for each character (excludes sentinel -1:
        // strings that ended at d are equal and already in place)
        for (int r = 0; r < R; r++)
            sort(a, lo + count[r], lo + count[r + 1] - 1, d + 1, aux);
    }


    // insertion sort a[lo..hi], starting at dth character
    private static void insertion(String[] a, int lo, int hi, int d)
    {
        for (int i = lo; i <= hi; i++)
            for (int j = i; j > lo && less(a[j], a[j - 1], d); j--)
                exch(a, j, j - 1);
    }

    // exchange a[i] and a[j]
    private static void exch(String[] a, int i, int j)
    {
        String temp = a[i];
        a[i] = a[j];
        a[j] = temp;
    }

    // is v less than w, starting at character d
    private static boolean less(String v, String w, int d)
    {
        // assert v.substring(0, d).equals(w.substring(0, d));
        for (int i = d; i < Math.min(v.length(), w.length()); i++)
        {
            if (v.charAt(i) < w.charAt(i)) return true;
            if (v.charAt(i) > w.charAt(i)) return false;
        }
        return v.length() < w.length();
    }


    /**
     * Rearranges the array of 32-bit integers in ascending order.
     * Negative values are supported: the leading byte is compared with its
     * sign bit flipped, which matches signed two's-complement order.
     *
     * @param a the array to be sorted
     */
    public static void sort(int[] a)
    {
        int n = a.length;
        int[] aux = new int[n];
        sort(a, 0, n - 1, 0, aux);
    }

    // return the dth byte of x (d = 0 is the most significant) as a key in [0, R)
    private static int byteAt(int x, int d)
    {
        int shift = BITS_PER_INT - BITS_PER_BYTE * (d + 1);
        int key = (x >> shift) & (R - 1);
        // the leading byte carries the sign: flipping its top bit maps
        // 0x80-0xFF (negative) below 0x00-0x7F (nonnegative)
        if (d == 0) key ^= R / 2;
        return key;
    }

    // MSD sort from a[lo] to a[hi], starting at the dth byte
    private static void sort(int[] a, int lo, int hi, int d, int[] aux)
    {

        // cutoff to insertion sort for small subarrays
        if (hi <= lo + CUTOFF)
        {
            insertion(a, lo, hi, d);
            return;
        }

        // compute frequency counts (need R = 256)
        int[] count = new int[R + 1];
        for (int i = lo; i <= hi; i++)
            count[byteAt(a[i], d) + 1]++;

        // transform counts to indices: count[r] = start of bucket r
        for (int r = 0; r < R; r++)
            count[r + 1] += count[r];

        // distribute; afterwards count[r] = one past the end of bucket r
        for (int i = lo; i <= hi; i++)
            aux[count[byteAt(a[i], d)]++] = a[i];

        // copy back
        for (int i = lo; i <= hi; i++)
            a[i] = aux[i - lo];

        // no more bytes: every bucket now holds identical values
        if (d == BYTES_PER_INT - 1) return;

        // recursively sort each nonempty bucket on the next byte
        int start = 0;
        for (int r = 0; r < R; r++)
        {
            int end = count[r];
            if (end > start)
                sort(a, lo + start, lo + end - 1, d + 1, aux);
            start = end;
        }
    }

    // insertion sort a[lo..hi]; compares whole ints, so d is not needed here
    private static void insertion(int[] a, int lo, int hi, int d)
    {
        for (int i = lo; i <= hi; i++)
            for (int j = i; j > lo && a[j] < a[j - 1]; j--)
                exch(a, j, j - 1);
    }

    // exchange a[i] and a[j]
    private static void exch(int[] a, int i, int j)
    {
        int temp = a[i];
        a[i] = a[j];
        a[j] = temp;
    }


    /**
     * Reads in a sequence of extended ASCII strings from standard input;
     * MSD radix sorts them;
     * and prints them to standard output in ascending order.
     *
     * @param args the command-line arguments
     */
    public static void main(String[] args)
    {
        String[] a = StdIn.readAllStrings();
        int n = a.length;
        sort(a);
        for (int i = 0; i < n; i++)
            StdOut.println(a[i]);
    }
}
小型子数组

高位优先的字符串排序的基本思想是很有效的:在一般的应用中,只需要检查若干个字符就能完成所有字符串的排序。换句话说,这种方法能够快速地将需要排序的数组切分为较小的数组。但这种切分也是一把双刃剑:我们肯定会需要处理大量微型数组,因此必须快速处理它们。小型子数组对于高位优先的字符串排序的性能至关重要。我们在其他递归排序算法中也遇到过这种情况(快速排序和归并排序),但小数组对于高位优先的字符串排序的影响尤其强烈。将小数组切换到插入排序对于高位优先的字符串排序算法是必须的。

等值键

高位优先的字符串排序中的第二个陷阱是,对于含有大量等值键的子数组的排序会较慢。如果相同的子字符串出现得过多,切换到插入排序的条件就不会被触发,那么递归方法就会检查所有相同键中的每一个字符。另外,键索引计数法无法有效判断字符串中的字符是否全部相同:它不仅需要检查每个字符和移动每个字符串,还需要初始化所有的频率统计并将它们转换为索引等。因此,高位优先的字符串排序的最坏情况就是所有的键均相同。大量含有相同前缀的键也会产生同样的问题,这在一般的应用场景中是很常见的。

三向字符串快速排序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
/**
 * 3-way string (radix) quicksort: partitions on one character position at a
 * time and hands small subarrays to insertion sort.
 */
public class Quick3string
{
    private static final int CUTOFF = 15; // subarrays this small or smaller use insertion sort

    // utility class, never instantiated
    private Quick3string() { }

    /**
     * Rearranges the array of strings in ascending order.
     * The array is shuffled first, then sorted starting at character 0.
     *
     * @param a the array to be sorted
     */
    public static void sort(String[] a)
    {
        StdRandom.shuffle(a);
        int last = a.length - 1;
        sort(a, 0, last, 0);
        assert isSorted(a);
    }

    // character of s at position d, or -1 when d is exactly the length of s
    private static int charAt(String s, int d)
    {
        assert d >= 0 && d <= s.length();
        if (d != s.length())
        {
            return s.charAt(d);
        }
        return -1;
    }


    // 3-way partition a[lo..hi] on the character at position d, then recurse
    private static void sort(String[] a, int lo, int hi, int d)
    {
        // small subarray: finish it off with insertion sort
        if (hi <= lo + CUTOFF)
        {
            insertion(a, lo, hi, d);
            return;
        }

        int pivot = charAt(a[lo], d);
        int below = lo;     // a[lo..below-1] holds keys smaller than the pivot
        int above = hi;     // a[above+1..hi] holds keys larger than the pivot
        int scan = lo + 1;
        while (scan <= above)
        {
            int key = charAt(a[scan], d);
            if (key < pivot)
            {
                exch(a, below, scan);
                below++;
                scan++;
            }
            else if (key > pivot)
            {
                exch(a, scan, above);
                above--;
            }
            else
            {
                scan++;
            }
        }

        // now a[lo..below-1] < pivot = a[below..above] < a[above+1..hi]
        sort(a, lo, below - 1, d);
        // strings in the middle part share character d; a pivot of -1 means they all ended here
        if (pivot >= 0)
        {
            sort(a, below, above, d + 1);
        }
        sort(a, above + 1, hi, d);
    }

    // insertion sort of a[lo..hi], comparing from character d onward
    private static void insertion(String[] a, int lo, int hi, int d)
    {
        for (int i = lo; i <= hi; i++)
        {
            int j = i;
            while (j > lo && less(a[j], a[j - 1], d))
            {
                exch(a, j, j - 1);
                j--;
            }
        }
    }

    // swap a[i] and a[j]
    private static void exch(String[] a, int i, int j)
    {
        String held = a[j];
        a[j] = a[i];
        a[i] = held;
    }

    // is v less than w, comparing from character d onward?
    // compares character by character rather than via substring extraction
    private static boolean less(String v, String w, int d)
    {
        assert v.substring(0, d).equals(w.substring(0, d));
        int shared = Math.min(v.length(), w.length());
        for (int i = d; i < shared; i++)
        {
            char cv = v.charAt(i);
            char cw = w.charAt(i);
            if (cv != cw)
            {
                return cv < cw;
            }
        }
        // no difference within the shared part: the shorter string is the smaller
        return v.length() < w.length();
    }

    // is every element at least as large as the one before it?
    private static boolean isSorted(String[] a)
    {
        for (int i = 1; i < a.length; i++)
        {
            if (a[i].compareTo(a[i - 1]) < 0)
            {
                return false;
            }
        }
        return true;
    }


    /**
     * Reads in a sequence of strings from standard input;
     * 3-way radix quicksorts them;
     * and prints them to standard output in ascending order.
     *
     * @param args the command-line arguments
     */
    public static void main(String[] args)
    {
        // read every string from standard input
        String[] a = StdIn.readAllStrings();

        // sort them in place
        sort(a);

        // print one string per line
        for (int i = 0; i < a.length; i++)
        {
            StdOut.println(a[i]);
        }
    }
}

测试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
*  % java Quick3string < shell.txt
* are
* by
* sea
* seashells
* seashells
* sells
* sells
* she
* she
* shells
* shore
* surely
* the
* the

单词查找树

基于单词查找树的符号表

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
/**
 * A symbol table with string keys, implemented as an R-way trie over the
 * extended ASCII alphabet. Every key character is used directly as an index
 * into a 256-entry child array.
 */
public class TrieST<Value>
{
    private static final int R = 256; // extended ASCII

    private Node root; // root of trie
    private int n;     // number of keys in trie

    // R-way trie node: an optional value plus one link per character
    private static class Node
    {
        private Object val;
        private Node[] next = new Node[R];
    }

    /**
     * Initializes an empty string symbol table.
     */
    public TrieST() {}


    /**
     * Returns the value associated with the given key.
     *
     * @param key the key
     * @return the value stored under {@code key}, or {@code null} if there is none
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public Value get(String key)
    {
        if (key == null) throw new IllegalArgumentException("argument to get() is null");
        Node hit = nodeFor(key);
        return (hit == null) ? null : (Value) hit.val;
    }

    /**
     * Does this symbol table contain the given key?
     *
     * @param key the key
     * @return {@code true} if a non-null value is stored under {@code key}
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public boolean contains(String key)
    {
        if (key == null) throw new IllegalArgumentException("argument to contains() is null");
        return get(key) != null;
    }

    // walk down from the root along the characters of key;
    // returns the node reached, or null if the path runs out
    private Node nodeFor(String key)
    {
        Node x = root;
        for (int d = 0; x != null && d < key.length(); d++)
        {
            x = x.next[key.charAt(d)];
        }
        return x;
    }

    /**
     * Inserts the key-value pair into the symbol table, overwriting the old value
     * with the new value if the key is already in the symbol table.
     * If the value is {@code null}, this effectively deletes the key from the symbol table.
     *
     * @param key the key
     * @param val the value
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public void put(String key, Value val)
    {
        if (key == null) throw new IllegalArgumentException("first argument to put() is null");
        if (val != null)
        {
            root = put(root, key, val, 0);
        }
        else
        {
            delete(key);
        }
    }

    // insert below x, where x sits at depth d; returns the (possibly new) subtrie root
    private Node put(Node x, String key, Value val, int d)
    {
        Node node = (x != null) ? x : new Node();
        if (d < key.length())
        {
            char c = key.charAt(d);
            node.next[c] = put(node.next[c], key, val, d + 1);
            return node;
        }
        // end of key: a node without a value means the key is new
        if (node.val == null) n++;
        node.val = val;
        return node;
    }

    /**
     * Returns the number of key-value pairs in this symbol table.
     *
     * @return the number of key-value pairs
     */
    public int size()
    {
        return n;
    }

    /**
     * Is this symbol table empty?
     *
     * @return {@code true} if the table holds no keys
     */
    public boolean isEmpty()
    {
        return size() == 0;
    }

    /**
     * Returns all keys in the symbol table as an {@code Iterable}.
     *
     * @return all keys, in the order the trie is traversed
     */
    public Iterable<String> keys()
    {
        return keysWithPrefix("");
    }

    /**
     * Returns all of the keys in the set that start with {@code prefix}.
     *
     * @param prefix the prefix
     * @return all keys that start with {@code prefix}
     */
    public Iterable<String> keysWithPrefix(String prefix)
    {
        Queue<String> results = new Queue<String>();
        Node start = nodeFor(prefix);
        collect(start, new StringBuilder(prefix), results);
        return results;
    }

    // enqueue every key in the subtrie rooted at x; prefix spells the path to x
    private void collect(Node x, StringBuilder prefix, Queue<String> results)
    {
        if (x == null) return;
        if (x.val != null) results.enqueue(prefix.toString());
        for (int i = 0; i < R; i++)
        {
            Node child = x.next[i];
            if (child == null) continue; // nothing below this character
            prefix.append((char) i);
            collect(child, prefix, results);
            prefix.deleteCharAt(prefix.length() - 1);
        }
    }

    /**
     * Returns all of the keys in the symbol table that match {@code pattern},
     * where . symbol is treated as a wildcard character.
     *
     * @param pattern the pattern
     * @return all keys that match {@code pattern}
     */
    public Iterable<String> keysThatMatch(String pattern)
    {
        Queue<String> results = new Queue<String>();
        collect(root, new StringBuilder(), pattern, results);
        return results;
    }

    // enqueue every key below x that matches pattern; the length of prefix is the current depth
    private void collect(Node x, StringBuilder prefix, String pattern, Queue<String> results)
    {
        if (x == null) return;

        int depth = prefix.length();
        if (depth == pattern.length())
        {
            // whole pattern consumed: x is a match only if it holds a value
            if (x.val != null) results.enqueue(prefix.toString());
            return;
        }

        char wanted = pattern.charAt(depth);
        if (wanted != '.')
        {
            // literal character: follow exactly one link
            prefix.append(wanted);
            collect(x.next[wanted], prefix, pattern, results);
            prefix.deleteCharAt(prefix.length() - 1);
            return;
        }

        // wildcard: try every link
        for (int i = 0; i < R; i++)
        {
            prefix.append((char) i);
            collect(x.next[i], prefix, pattern, results);
            prefix.deleteCharAt(prefix.length() - 1);
        }
    }

    /**
     * Returns the string in the symbol table that is the longest prefix of {@code query},
     * or {@code null}, if no such string.
     *
     * @param query the query string
     * @return the longest key that is a prefix of {@code query}, or {@code null}
     * @throws IllegalArgumentException if {@code query} is {@code null}
     */
    public String longestPrefixOf(String query)
    {
        if (query == null) throw new IllegalArgumentException("argument to longestPrefixOf() is null");

        // follow the query down the trie, remembering the depth of the
        // deepest node seen so far that holds a value (-1 if none)
        int best = -1;
        Node x = root;
        for (int d = 0; x != null; d++)
        {
            if (x.val != null) best = d;
            if (d == query.length()) break;
            x = x.next[query.charAt(d)];
        }

        if (best == -1) return null;
        return query.substring(0, best);
    }

    /**
     * Removes the key from the set if the key is present.
     *
     * @param key the key
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public void delete(String key)
    {
        if (key == null) throw new IllegalArgumentException("argument to delete() is null");
        root = delete(root, key, 0);
    }

    // delete key below x (at depth d); returns null when the subtrie becomes empty
    private Node delete(Node x, String key, int d)
    {
        if (x == null) return null;

        if (d < key.length())
        {
            char c = key.charAt(d);
            x.next[c] = delete(x.next[c], key, d + 1);
        }
        else
        {
            if (x.val != null) n--;
            x.val = null;
        }

        // keep x only if it still carries a value or has at least one child
        return isBare(x) ? null : x;
    }

    // true if x holds no value and has no children
    private static boolean isBare(Node x)
    {
        if (x.val != null) return false;
        for (Node child : x.next)
        {
            if (child != null) return false;
        }
        return true;
    }

    /**
     * Unit tests the {@code TrieST} data type.
     *
     * @param args the command-line arguments
     */
    public static void main(String[] args)
    {
        // build symbol table from standard input; the value is the position of the key
        TrieST<Integer> st = new TrieST<Integer>();
        int position = 0;
        while (!StdIn.isEmpty())
        {
            String word = StdIn.readString();
            st.put(word, position);
            position++;
        }

        // print every key only when the table is small
        if (st.size() < 100)
        {
            StdOut.println("keys(\"\"):");
            for (String key : st.keys())
            {
                StdOut.println(key + " " + st.get(key));
            }
            StdOut.println();
        }

        StdOut.println("longestPrefixOf(\"shellsort\"):");
        StdOut.println(st.longestPrefixOf("shellsort"));
        StdOut.println();

        StdOut.println("longestPrefixOf(\"quicksort\"):");
        StdOut.println(st.longestPrefixOf("quicksort"));
        StdOut.println();

        StdOut.println("keysWithPrefix(\"shor\"):");
        for (String s : st.keysWithPrefix("shor"))
        {
            StdOut.println(s);
        }
        StdOut.println();

        StdOut.println("keysThatMatch(\".he.l.\"):");
        for (String s : st.keysThatMatch(".he.l."))
        {
            StdOut.println(s);
        }
    }
}

基于三向单词查找树的符号表

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
/**
 * A symbol table with non-empty string keys, implemented as a ternary search
 * trie: each node holds one character and has a left, middle and right link
 * for smaller, equal and larger characters.
 */
public class TST<Value>
{
    private int n;            // size
    private Node<Value> root; // root of TST

    private static class Node<Value>
    {
        private char c;                       // character
        private Node<Value> left, mid, right; // left, middle, and right subtries
        private Value val;                    // value associated with string
    }

    /**
     * Initializes an empty string symbol table.
     */
    public TST() {}

    /**
     * Returns the number of key-value pairs in this symbol table.
     *
     * @return the number of key-value pairs
     */
    public int size()
    {
        return n;
    }

    /**
     * Does this symbol table contain the given key?
     *
     * @param key the key
     * @return {@code true} if a non-null value is stored under {@code key}
     * @throws IllegalArgumentException if {@code key} is {@code null} or empty
     */
    public boolean contains(String key)
    {
        if (key == null)
        {
            throw new IllegalArgumentException("argument to contains() is null");
        }
        return get(key) != null;
    }

    /**
     * Returns the value associated with the given key.
     *
     * @param key the key
     * @return the value stored under {@code key}, or {@code null} if there is none
     * @throws IllegalArgumentException if {@code key} is {@code null} or empty
     */
    public Value get(String key)
    {
        if (key == null)
        {
            throw new IllegalArgumentException("calls get() with null argument");
        }
        if (key.length() == 0) throw new IllegalArgumentException("key must have length >= 1");
        Node<Value> x = get(root, key, 0);
        if (x == null) return null;
        return x.val;
    }

    // return subtrie corresponding to given key
    private Node<Value> get(Node<Value> x, String key, int d)
    {
        if (x == null) return null;
        // also guards keysWithPrefix(), which calls this helper directly
        if (key.length() == 0) throw new IllegalArgumentException("key must have length >= 1");
        char c = key.charAt(d);
        if (c < x.c) return get(x.left, key, d);
        else if (c > x.c) return get(x.right, key, d);
        else if (d < key.length() - 1) return get(x.mid, key, d + 1);
        else return x;
    }

    /**
     * Inserts the key-value pair into the symbol table, overwriting the old value
     * with the new value if the key is already in the symbol table.
     * If the value is {@code null}, this effectively deletes the key from the symbol table;
     * a {@code null} value for a key that is not present changes nothing.
     *
     * @param key the key
     * @param val the value
     * @throws IllegalArgumentException if {@code key} is {@code null} or empty
     */
    public void put(String key, Value val)
    {
        if (key == null) {
            throw new IllegalArgumentException("calls put() with null key");
        }
        boolean present = contains(key);
        if (val == null)
        {
            // deleting an absent key must not touch the size or allocate nodes
            if (!present) return;
            n--;
        }
        else if (!present)
        {
            n++;
        }
        root = put(root, key, val, 0);
    }

    // insert below x, matching the dth character of key; returns the (possibly new) subtrie root
    private Node<Value> put(Node<Value> x, String key, Value val, int d)
    {
        char c = key.charAt(d);
        if (x == null)
        {
            x = new Node<Value>();
            x.c = c;
        }
        if (c < x.c) x.left = put(x.left, key, val, d);
        else if (c > x.c) x.right = put(x.right, key, val, d);
        else if (d < key.length() - 1) x.mid = put(x.mid, key, val, d + 1);
        else x.val = val;
        return x;
    }

    /**
     * Returns the string in the symbol table that is the longest prefix of {@code query}.
     * If no key is a prefix of {@code query}, the empty string is returned;
     * {@code null} is returned only when {@code query} itself is empty.
     *
     * @param query the query string
     * @return the longest key that is a prefix of {@code query}, the empty
     * string if there is none, or {@code null} if {@code query} is empty
     * @throws IllegalArgumentException if {@code query} is {@code null}
     */
    public String longestPrefixOf(String query)
    {
        if (query == null)
        {
            throw new IllegalArgumentException("calls longestPrefixOf() with null argument");
        }
        if (query.length() == 0) return null;
        int length = 0;
        Node<Value> x = root;
        int i = 0;
        while (x != null && i < query.length())
        {
            char c = query.charAt(i);
            if (c < x.c) x = x.left;
            else if (c > x.c) x = x.right;
            else
            {
                // character matched: a value here means query[0..i] is a key
                i++;
                if (x.val != null) length = i;
                x = x.mid;
            }
        }
        return query.substring(0, length);
    }

    /**
     * Returns all keys in the symbol table as an {@code Iterable}.
     *
     * @return all keys, in the order the trie is traversed
     */
    public Iterable<String> keys()
    {
        Queue<String> queue = new Queue<String>();
        collect(root, new StringBuilder(), queue);
        return queue;
    }

    /**
     * Returns all of the keys in the set that start with {@code prefix}.
     *
     * @param prefix the prefix
     * @return all keys that start with {@code prefix}
     * @throws IllegalArgumentException if {@code prefix} is {@code null},
     * or if it is empty while the table is non-empty
     */
    public Iterable<String> keysWithPrefix(String prefix)
    {
        if (prefix == null)
        {
            throw new IllegalArgumentException("calls keysWithPrefix() with null argument");
        }
        Queue<String> queue = new Queue<String>();
        Node<Value> x = get(root, prefix, 0);
        if (x == null) return queue;
        // the prefix itself may be a key; longer keys hang off its middle link
        if (x.val != null) queue.enqueue(prefix);
        collect(x.mid, new StringBuilder(prefix), queue);
        return queue;
    }

    // all keys in subtrie rooted at x with given prefix
    private void collect(Node<Value> x, StringBuilder prefix, Queue<String> queue)
    {
        if (x == null) return;
        collect(x.left, prefix, queue);
        if (x.val != null) queue.enqueue(prefix.toString() + x.c);
        collect(x.mid, prefix.append(x.c), queue);
        prefix.deleteCharAt(prefix.length() - 1);
        collect(x.right, prefix, queue);
    }


    /**
     * Returns all of the keys in the symbol table that match {@code pattern},
     * where . symbol is treated as a wildcard character.
     * An empty pattern matches nothing, because every key has at least one character.
     *
     * @param pattern the pattern
     * @return all keys that match {@code pattern}
     * @throws IllegalArgumentException if {@code pattern} is {@code null}
     */
    public Iterable<String> keysThatMatch(String pattern)
    {
        if (pattern == null)
        {
            throw new IllegalArgumentException("calls keysThatMatch() with null argument");
        }
        Queue<String> queue = new Queue<String>();
        // the recursive helper reads pattern.charAt(0), so stop here for an empty pattern
        if (pattern.length() == 0) return queue;
        collect(root, new StringBuilder(), 0, pattern, queue);
        return queue;
    }

    // all keys below x that match pattern from position i onward; prefix spells the path so far
    private void collect(Node<Value> x, StringBuilder prefix, int i, String pattern, Queue<String> queue)
    {
        if (x == null) return;
        char c = pattern.charAt(i);
        if (c == '.' || c < x.c) collect(x.left, prefix, i, pattern, queue);
        if (c == '.' || c == x.c)
        {
            if (i == pattern.length() - 1 && x.val != null) queue.enqueue(prefix.toString() + x.c);
            if (i < pattern.length() - 1)
            {
                collect(x.mid, prefix.append(x.c), i + 1, pattern, queue);
                prefix.deleteCharAt(prefix.length() - 1);
            }
        }
        if (c == '.' || c > x.c) collect(x.right, prefix, i, pattern, queue);
    }


    /**
     * Unit tests the {@code TST} data type.
     *
     * @param args the command-line arguments
     */
    public static void main(String[] args)
    {
        // build symbol table from standard input
        TST<Integer> st = new TST<Integer>();
        for (int i = 0; !StdIn.isEmpty(); i++)
        {
            String key = StdIn.readString();
            st.put(key, i);
        }

        // print results
        if (st.size() < 100)
        {
            StdOut.println("keys(\"\"):");
            for (String key : st.keys())
            {
                StdOut.println(key + " " + st.get(key));
            }
            StdOut.println();
        }

        StdOut.println("longestPrefixOf(\"shellsort\"):");
        StdOut.println(st.longestPrefixOf("shellsort"));
        StdOut.println();

        StdOut.println("longestPrefixOf(\"shell\"):");
        StdOut.println(st.longestPrefixOf("shell"));
        StdOut.println();

        StdOut.println("keysWithPrefix(\"shor\"):");
        for (String s : st.keysWithPrefix("shor"))
            StdOut.println(s);
        StdOut.println();

        StdOut.println("keysThatMatch(\".he.l.\"):");
        for (String s : st.keysThatMatch(".he.l."))
            StdOut.println(s);
    }
}

子字符串查找

断更