|
1 #!/usr/bin/env python |
|
2 # |
|
3 # Copyright 2007 Google Inc. |
|
4 # |
|
5 # Licensed under the Apache License, Version 2.0 (the "License"); |
|
6 # you may not use this file except in compliance with the License. |
|
7 # You may obtain a copy of the License at |
|
8 # |
|
9 # http://www.apache.org/licenses/LICENSE-2.0 |
|
10 # |
|
11 # Unless required by applicable law or agreed to in writing, software |
|
12 # distributed under the License is distributed on an "AS IS" BASIS, |
|
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
14 # See the License for the specific language governing permissions and |
|
15 # limitations under the License. |
|
16 # |
|
17 |
|
18 """Key range representation and splitting.""" |
|
19 |
|
20 |
|
21 import os |
|
22 |
|
23 try: |
|
24 import simplejson |
|
25 except ImportError: |
|
26 simplejson = None |
|
27 |
|
28 from google.appengine.api import datastore |
|
29 from google.appengine.datastore import datastore_pb |
|
30 from google.appengine.ext import db |
|
31 |
|
32 |
|
class Error(Exception):
    """Root of the exception hierarchy defined by this module.

    Catching this type catches every error this module raises on purpose.
    """
|
35 |
|
36 |
|
class KeyRangeError(Error):
    """Raised when a KeyRange cannot be generated or is inconsistent."""
|
39 |
|
40 |
|
class SimplejsonUnavailableError(Error):
    """Raised when JSON functionality is used but simplejson is unavailable."""
|
43 |
|
class EmptyDbQuery(db.Query):
    """A db.Query whose result set is always empty.

    Every retrieval method is overridden to answer immediately with the
    "nothing found" value for that method.
    """

    def get(self):
        """Return None, because there is no first result."""
        return None

    def fetch(self, limit=1000, offset=0):
        """Return an empty list; limit and offset are ignored."""
        return []

    def count(self, limit=1000):
        """Return 0; limit is ignored."""
        return 0
|
55 |
|
56 |
|
class EmptyDatastoreQuery(datastore.Query):
    """A datastore.Query whose result set is always empty."""

    def __init__(self, kind):
        """Initialize the underlying query for the given kind.

        Args:
          kind: The entity kind, passed straight to datastore.Query.
        """
        datastore.Query.__init__(self, kind)

    def _Run(self, *unused_args, **unused_kwargs):
        """Return an iterator built from an empty QueryResult.

        The result protocol buffer is given cursor 0 and more_results False
        before being wrapped in a datastore.Iterator. All arguments are
        ignored.
        """
        result_pb = datastore_pb.QueryResult()
        result_pb.set_cursor(0)
        result_pb.set_more_results(False)
        return datastore.Iterator(result_pb)

    def Count(self, *unused_args, **unused_kwargs):
        """Return 0; all arguments are ignored."""
        return 0

    def Get(self, *unused_args, **unused_kwargs):
        """Return an empty list; all arguments are ignored."""
        return []

    def Next(self, *unused_args, **unused_kwargs):
        """Return an empty list; all arguments are ignored."""
        return []
|
77 |
|
78 |
|
class KeyRange(object):
    """Represents a range of keys in the datastore.

    A KeyRange object represents a key range
      (key_start, include_start, key_end, include_end)
    and a scan direction (KeyRange.DESC or KeyRange.ASC).

    A falsy key_start or key_end (normally None) means that side of the
    range is unbounded: no filter is generated for it and its include flag
    is ignored when comparing ranges.

    This class targets Python 2: it relies on the unichr, long, basestring
    and cmp builtins.
    """

    DESC = 'DESC'
    ASC = 'ASC'

    def __init__(self,
                 key_start=None,
                 key_end=None,
                 direction=None,
                 include_start=True,
                 include_end=True):
        """Initialize a KeyRange object.

        Args:
          key_start: The starting key for this range.
          key_end: The ending key for this range.
          direction: The direction of the query for this range. Defaults to
            KeyRange.ASC when None.
          include_start: Whether the start key should be included in the
            range.
          include_end: Whether the end key should be included in the range.
        """
        if direction is None:
            direction = KeyRange.ASC
        # Deliberately left as an assert so the exception type seen by
        # existing callers (AssertionError) does not change.
        assert direction in (KeyRange.ASC, KeyRange.DESC)
        self.direction = direction
        self.key_start = key_start
        self.key_end = key_end
        self.include_start = include_start
        self.include_end = include_end

    def __str__(self):
        """Return the range in interval notation, e.g. ASC['a'-'b').

        Square brackets mark an inclusive bound and parentheses an exclusive
        one.
        """
        left_side = '[' if self.include_start else '('
        # An exclusive end bound closes with ')', not '('.
        right_side = ']' if self.include_end else ')'
        return '%s%s%r-%r%s' % (self.direction, left_side, self.key_start,
                                self.key_end, right_side)

    def __repr__(self):
        """Return a constructor-like representation of this range."""
        return ('key_range.KeyRange(key_start=%r,key_end=%r,direction=%r,'
                'include_start=%r,include_end=%r)') % (self.key_start,
                                                       self.key_end,
                                                       self.direction,
                                                       self.include_start,
                                                       self.include_end)

    def _is_trivially_empty(self):
        """Return True if both bounds are equal and both are exclusive."""
        return (self.key_start == self.key_end and
                not (self.include_start or self.include_end))

    def _key_filters(self):
        """Return the ('__key__ <op>', key) pairs that bound this range.

        A bound whose key is falsy is unbounded and contributes no pair.
        """
        filters = []
        if self.key_start:
            start_comparator = '>=' if self.include_start else '>'
            filters.append(('__key__ %s' % start_comparator, self.key_start))
        if self.key_end:
            end_comparator = '<=' if self.include_end else '<'
            filters.append(('__key__ %s' % end_comparator, self.key_end))
        return filters

    def filter_query(self, query):
        """Add query filter to restrict to this key range.

        Args:
          query: A db.Query instance.

        Returns:
          The input query restricted to this key range or an empty query if
          this key range is empty.
        """
        assert isinstance(query, db.Query)
        if self._is_trivially_empty():
            return EmptyDbQuery()
        for property_operator, key in self._key_filters():
            query.filter(property_operator, key)
        return query

    def filter_datastore_query(self, query):
        """Add query filter to restrict to this key range.

        Args:
          query: A datastore.Query instance.

        Returns:
          The input query restricted to this key range or an empty query if
          this key range is empty.
        """
        assert isinstance(query, datastore.Query)
        if self._is_trivially_empty():
            return EmptyDatastoreQuery(query.kind)
        for property_operator, key in self._key_filters():
            query.update({property_operator: key})
        return query

    def __get_direction(self, asc, desc):
        """Map self.direction onto one of two caller-supplied values.

        Args:
          asc: Argument to return if self.direction is KeyRange.ASC
          desc: Argument to return if self.direction is KeyRange.DESC

        Returns:
          asc or desc appropriately

        Raises:
          KeyRangeError: if self.direction is not in
            (KeyRange.ASC, KeyRange.DESC).
        """
        if self.direction == KeyRange.ASC:
            return asc
        elif self.direction == KeyRange.DESC:
            return desc
        else:
            # Format the message here; exceptions do not interpolate extra
            # positional arguments the way logging calls do.
            raise KeyRangeError(
                'KeyRange direction unexpected: %s' % (self.direction,))

    def make_directed_query(self, kind_class):
        """Construct a query for this key range, including the scan direction.

        Args:
          kind_class: A kind implementation class.

        Returns:
          A db.Query instance.

        Raises:
          KeyRangeError: if self.direction is not in
            (KeyRange.ASC, KeyRange.DESC).
        """
        # A leading '-' on the property name requests descending order.
        order_prefix = self.__get_direction('', '-')
        query = db.Query(kind_class)
        query.order('%s__key__' % order_prefix)
        return self.filter_query(query)

    def make_directed_datastore_query(self, kind):
        """Construct a query for this key range, including the scan direction.

        Args:
          kind: A string.

        Returns:
          A datastore.Query instance.

        Raises:
          KeyRangeError: if self.direction is not in
            (KeyRange.ASC, KeyRange.DESC).
        """
        direction = self.__get_direction(datastore.Query.ASCENDING,
                                         datastore.Query.DESCENDING)
        query = datastore.Query(kind)
        query.Order(('__key__', direction))
        return self.filter_datastore_query(query)

    def make_ascending_query(self, kind_class):
        """Construct an ascending query for this key range.

        The range's own scan direction is ignored; results are always
        ordered by ascending __key__.

        Args:
          kind_class: A kind implementation class.

        Returns:
          A db.Query instance.
        """
        query = db.Query(kind_class)
        query.order('__key__')
        return self.filter_query(query)

    def make_ascending_datastore_query(self, kind):
        """Construct an ascending query for this key range.

        The range's own scan direction is ignored; results are always
        ordered by ascending __key__.

        Args:
          kind: A string.

        Returns:
          A datastore.Query instance.
        """
        query = datastore.Query(kind)
        query.Order(('__key__', datastore.Query.ASCENDING))
        return self.filter_datastore_query(query)

    def split_range(self, batch_size=0):
        """Split this key range into a list of at most two ranges.

        This method attempts to split the key range approximately in half.
        Numeric ranges are split in the middle into two equal ranges and
        string ranges are split lexicographically in the middle. When the
        final numeric ids of the two keys are no more than batch_size apart,
        the split key equals key_start, so the first returned range is
        degenerate and the second carries the whole original range.

        A range with an unbounded side is never split: it is returned as a
        single range, scanned ASC when the start is unbounded and DESC when
        only the end is unbounded.

        Note that splitting is done without knowledge of the distribution
        of actual entities in the key range, so there is no guarantee (nor
        any particular reason to believe) that the entities of the range
        are evenly split.

        Args:
          batch_size: The maximum size of a key range that should not be
            split.

        Returns:
          A list of one or two key ranges covering the same space as this
          range.
        """
        key_start = self.key_start
        key_end = self.key_end
        include_start = self.include_start
        include_end = self.include_end

        if not key_start:
            range_specs = [(key_start, include_start, key_end, include_end,
                            KeyRange.ASC)]
        elif not key_end:
            range_specs = [(key_start, include_start, key_end, include_end,
                            KeyRange.DESC)]
        else:
            key_split = KeyRange.split_keys(key_start, key_end, batch_size)

            # The first range ends inclusively at the split key, unless the
            # split key is key_start itself, in which case it may only be
            # included if key_start was.
            if key_split == key_start:
                first_include_end = include_start
            else:
                first_include_end = True

            # The second range starts exclusively at the split key; if the
            # split key is key_end it is empty, so key_end must be excluded
            # to avoid covering it twice.
            if key_split == key_end:
                second_include_end = False
            else:
                second_include_end = include_end

            range_specs = [
                (key_start, include_start, key_split, first_include_end,
                 KeyRange.DESC),
                (key_split, False, key_end, second_include_end,
                 KeyRange.ASC),
            ]

        # Distinct loop names: Python 2 list comprehensions leak their
        # variables into the enclosing scope.
        return [KeyRange(key_start=spec_start,
                         include_start=spec_include_start,
                         key_end=spec_end,
                         include_end=spec_include_end,
                         direction=spec_direction)
                for (spec_start, spec_include_start, spec_end,
                     spec_include_end, spec_direction) in range_specs]

    def _comparison_fields(self):
        """Return the list of fields used to order key ranges.

        The include flag of an unbounded (falsy) side is normalized to
        False because it carries no meaning there.
        """
        return [self.key_start,
                self.key_end,
                self.direction,
                self.include_start if self.key_start else False,
                self.include_end if self.key_end else False]

    def __cmp__(self, other):
        """Compare two key ranges.

        Key ranges with a value of None for key_start or key_end, are always
        considered to have include_start=False or include_end=False,
        respectively, when comparing. Since None indicates an unbounded side
        of the range, the include specifier is meaningless. The ordering
        generated is total but somewhat arbitrary.

        Args:
          other: An object to compare to this one.

        Returns:
          -1: if this key range is less than other.
          0: if this key range is equal to other.
          1: if this key range is greater than other, or other is not a
            KeyRange.
        """
        if not isinstance(other, KeyRange):
            return 1
        return cmp(self._comparison_fields(), other._comparison_fields())

    @staticmethod
    def bisect_string_range(start, end):
        """Returns a string that is approximately in the middle of the range.

        (start, end) is treated as a string range, and it is assumed
        start <= end in the usual lexicographic string ordering. The output
        key mid is guaranteed to satisfy start <= mid <= end.

        The method proceeds by comparing initial characters of start and
        end. When the characters are equal, they are appended to the mid
        string. In the first place that the characters differ, the
        difference characters are averaged and this average is appended to
        the mid string. If averaging resulted in rounding down, an
        additional character is added to the mid string to make up for the
        rounding down. This extra step is necessary for correctness in
        the case that the average of the two characters is equal to the
        character in the start string.

        This method makes the assumption that most keys are ascii and it
        attempts to perform splitting within the ascii range when that
        results in a valid split.

        Args:
          start: A string.
          end: A string such that start <= end.

        Returns:
          A string mid such that start <= mid <= end.
        """
        if start == end:
            return start
        # Terminate both strings with NUL so that a string which is a strict
        # prefix of the other still has a character to compare against.
        start += '\0'
        end += '\0'
        midpoint = []
        expected_max = 127
        for i in range(min(len(start), len(end))):
            if start[i] == end[i]:
                midpoint.append(start[i])
                continue
            ord_sum = ord(start[i]) + ord(end[i])
            midpoint.append(unichr(ord_sum // 2))
            if ord_sum % 2:
                # The average was rounded down and may equal start[i]; add
                # one more character above start's next character so that
                # the result still sorts at or after start.
                if len(start) > i + 1:
                    ord_start = ord(start[i + 1])
                else:
                    ord_start = 0
                if ord_start < expected_max:
                    ord_split = (expected_max + ord_start) // 2
                else:
                    ord_split = (0xFFFF + ord_start) // 2
                midpoint.append(unichr(ord_split))
            break
        return ''.join(midpoint)

    @staticmethod
    def split_keys(key_start, key_end, batch_size):
        """Return a key that is between key_start and key_end inclusive.

        This method compares components of the ancestor paths of key_start
        and key_end. The first place in the path that differs is
        approximately split in half. If the kind components differ, a new
        non-existent kind halfway between the two is used to split the
        space. If the id_or_name components differ, then a new id_or_name
        that is halfway between the two is selected. If the lower
        id_or_name is numeric and the upper id_or_name is a string, then
        the minimum string key, unichr(0), is used as the split id_or_name.
        The key that is returned is the shared portion of the ancestor path
        followed by the generated split component.

        Args:
          key_start: A db.Key instance for the lower end of a range.
          key_end: A db.Key instance for the upper end of a range.
          batch_size: The maximum size of a range that should not be split.

        Returns:
          A db.Key instance, k, such that key_start <= k <= key_end.
        """
        assert key_start.app() == key_end.app()
        path1 = key_start.to_path()
        path2 = key_end.to_path()
        len1 = len(path1)
        len2 = len(path2)
        # Paths are flat sequences of (kind, id_or_name) pairs.
        assert len1 % 2 == 0
        assert len2 % 2 == 0
        out_path = []
        min_path_len = min(len1, len2) // 2
        for i in range(min_path_len):
            kind1 = path1[2 * i]
            kind2 = path2[2 * i]

            if kind1 != kind2:
                split_kind = KeyRange.bisect_string_range(kind1, kind2)
                out_path.append(split_kind)
                out_path.append(unichr(0))
                break

            # True only when this is the final component of both paths;
            # only then may a small numeric range be kept intact.
            last = (len1 == len2 == 2 * (i + 1))

            id_or_name1 = path1[2 * i + 1]
            id_or_name2 = path2[2 * i + 1]
            id_or_name_split = KeyRange._split_id_or_name(
                id_or_name1, id_or_name2, batch_size, last)
            out_path.append(kind1)
            if id_or_name1 == id_or_name_split:
                # No split at this level: keep the shared prefix and look
                # at the next path component.
                out_path.append(id_or_name1)
            else:
                out_path.append(id_or_name_split)
                break

        # NOTE(review): the key is rebuilt from the path alone, so the app
        # of key_start is not passed on explicitly - confirm that is
        # acceptable for callers working across apps.
        return db.Key.from_path(*out_path)

    @staticmethod
    def _split_id_or_name(id_or_name1, id_or_name2, batch_size,
                          maintain_batches):
        """Return an id_or_name that is between id_or_name1 and id_or_name2.

        Attempts to split the range [id_or_name1, id_or_name2] in half,
        unless maintain_batches is true and the size of the range
        [id_or_name1, id_or_name2] is less than or equal to batch_size.

        Args:
          id_or_name1: A number or string or the id_or_name component of a
            key
          id_or_name2: A number or string or the id_or_name component of a
            key
          batch_size: The range size that will not be split if
            maintain_batches is true.
          maintain_batches: A boolean for whether to keep small ranges
            intact.

        Returns:
          An id_or_name such that id_or_name1 <= id_or_name <= id_or_name2.
        """
        if (isinstance(id_or_name1, (int, long)) and
                isinstance(id_or_name2, (int, long))):
            if not maintain_batches or id_or_name2 - id_or_name1 > batch_size:
                return (id_or_name1 + id_or_name2) // 2
            else:
                return id_or_name1
        elif (isinstance(id_or_name1, basestring) and
              isinstance(id_or_name2, basestring)):
            return KeyRange.bisect_string_range(id_or_name1, id_or_name2)
        else:
            # The only remaining supported mix is a numeric lower bound
            # with a string upper bound; split at the smallest string.
            assert (isinstance(id_or_name1, (int, long)) and
                    isinstance(id_or_name2, basestring))
            return unichr(0)

    def to_json(self):
        """Serialize KeyRange to json.

        Keys are written using str(key); a falsy key is written as null.

        Returns:
          string with KeyRange json representation.

        Raises:
          SimplejsonUnavailableError: if simplejson could not be imported.
        """
        if simplejson is None:
            raise SimplejsonUnavailableError(
                "JSON functionality requires simplejson to be available")

        def key_to_str(key):
            if key:
                return str(key)
            return None

        return simplejson.dumps({
            "direction": self.direction,
            "key_start": key_to_str(self.key_start),
            "key_end": key_to_str(self.key_end),
            "include_start": self.include_start,
            "include_end": self.include_end,
        }, sort_keys=True)

    @staticmethod
    def from_json(json_str):
        """Deserialize KeyRange from its json representation.

        Args:
          json_str: string with json representation created by to_json.

        Returns:
          deserialized KeyRange instance.

        Raises:
          SimplejsonUnavailableError: if simplejson could not be imported.
        """
        if simplejson is None:
            raise SimplejsonUnavailableError(
                "JSON functionality requires simplejson to be available")

        def key_from_str(key_str):
            if key_str:
                return db.Key(key_str)
            return None

        fields = simplejson.loads(json_str)
        return KeyRange(key_from_str(fields["key_start"]),
                        key_from_str(fields["key_end"]),
                        fields["direction"],
                        fields["include_start"],
                        fields["include_end"])