|
1 #!/usr/bin/env python |
|
2 # |
|
3 # Copyright 2007 Google Inc. |
|
4 # |
|
5 # Licensed under the Apache License, Version 2.0 (the "License"); |
|
6 # you may not use this file except in compliance with the License. |
|
7 # You may obtain a copy of the License at |
|
8 # |
|
9 # http://www.apache.org/licenses/LICENSE-2.0 |
|
10 # |
|
11 # Unless required by applicable law or agreed to in writing, software |
|
12 # distributed under the License is distributed on an "AS IS" BASIS, |
|
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
14 # See the License for the specific language governing permissions and |
|
15 # limitations under the License. |
|
16 # |
|
17 |
|
18 |
|
19 import struct |
|
20 import array |
|
21 import string |
|
22 import re |
|
23 from google.pyglib.gexcept import AbstractMethod |
|
24 import httplib |
|
25 |
|
# Public names exported by `from <this module> import *`.
__all__ = [
    'ProtocolMessage',
    'Encoder',
    'Decoder',
    'ProtocolBufferDecodeError',
    'ProtocolBufferEncodeError',
    'ProtocolBufferReturnError',
]
|
30 |
|
# Splits an absolute http/https URL into (scheme, host[:port], path);
# sendCommand uses it to follow redirects from the Location header.
URL_RE = re.compile(r'^(https?)://([^/]+)(/.*)$')
|
32 |
|
class ProtocolMessage:
  """Base class for protocol buffer message objects.

  Most methods raise AbstractMethod and must be overridden by subclasses.
  Encode() and MergeFromString() first try an optional accelerated hook
  (_CEncode / _CMergeFromString); when that hook raises AbstractMethod they
  fall back to the pure-Python Encoder / Decoder.
  """

  def __init__(self, contents=None):
    """Abstract constructor; __setstate__ passes the pickled state as
    contents."""
    raise AbstractMethod

  def Clear(self):
    """Abstract: clear the message. ParseFromString, Parse and CopyFrom call
    this before merging."""
    raise AbstractMethod

  def IsInitialized(self, debug_strs=None):
    """Abstract: return whether the message is initialized.

    Output(), Merge() and MergeFromString() pass a list as debug_strs and,
    when this returns a false value, join its entries into the message of
    the exception they raise.
    """
    raise AbstractMethod

  def Encode(self):
    """Serialize the message and return the encoded bytes.

    Raises:
      ProtocolBufferEncodeError: on the pure-Python path, if the message is
        not initialized (raised by Output()).
    """
    try:
      return self._CEncode()
    except AbstractMethod:
      pass
    # No accelerated encoder available: use the pure-Python Encoder.
    e = Encoder()
    self.Output(e)
    buf = e.buffer()
    # array.tostring() was removed in Python 3.9; tobytes() replaces it.
    if hasattr(buf, 'tobytes'):
      return buf.tobytes()
    return buf.tostring()

  def _CEncode(self):
    """Optional accelerated encoder; raising AbstractMethod makes Encode()
    use the pure-Python path."""
    raise AbstractMethod

  def ParseFromString(self, s):
    """Clear the message, then merge the serialized bytes s into it."""
    self.Clear()
    self.MergeFromString(s)

  def MergeFromString(self, s):
    """Merge the serialized bytes s into this message.

    Raises:
      ProtocolBufferDecodeError: if the merged message is not initialized;
        on the pure-Python path the Decoder also raises it for malformed s.
    """
    try:
      self._CMergeFromString(s)
    except AbstractMethod:
      # No accelerated parser: decode with the pure-Python Decoder.
      # Merge() performs the IsInitialized() check itself.
      a = array.array('B', s)
      self.Merge(Decoder(a, 0, len(a)))
    else:
      # Kept outside the try so that an AbstractMethod raised by
      # IsInitialized() cannot trigger the fallback and merge s twice.
      dbg = []
      if not self.IsInitialized(dbg):
        raise ProtocolBufferDecodeError('\n\t'.join(dbg))

  def _CMergeFromString(self, s):
    """Optional accelerated parser; raising AbstractMethod makes
    MergeFromString() use the pure-Python path."""
    raise AbstractMethod

  def __getstate__(self):
    """Pickle support: the state is the encoded message."""
    return self.Encode()

  def __setstate__(self, contents_):
    """Pickle support: hand the encoded state to __init__ as contents."""
    self.__init__(contents=contents_)

  def sendCommand(self, server, url, response, follow_redirects=1,
                  secure=0, keyfile=None, certfile=None):
    """POST the encoded message to server/url and parse the reply.

    Args:
      server: host (optionally host:port) given to httplib.
      url: request path.
      response: message to parse the reply body into, or None to ignore it.
      follow_redirects: how many HTTP 302 redirects may still be followed.
      secure: use HTTPS when true.
      keyfile, certfile: client key/certificate files; only used when both
        are given and secure is true.

    Returns:
      response (parsed in place when it is not None).

    Raises:
      ProtocolBufferReturnError: carrying the HTTP status when the final
        response status is not 200.
    """
    data = self.Encode()
    if secure:
      if keyfile and certfile:
        conn = httplib.HTTPSConnection(server, key_file=keyfile,
                                       cert_file=certfile)
      else:
        conn = httplib.HTTPSConnection(server)
    else:
      conn = httplib.HTTPConnection(server)
    try:
      conn.putrequest("POST", url)
      conn.putheader("Content-Length", "%d" % len(data))
      conn.endheaders()
      conn.send(data)
      resp = conn.getresponse()
      if follow_redirects > 0 and resp.status == 302:
        # getheader() returns None when Location is absent; treat that like
        # an unparseable Location (fall through to the status check).
        m = URL_RE.match(resp.getheader('Location') or '')
        if m:
          protocol, server, url = m.groups()
          return self.sendCommand(server, url, response,
                                  follow_redirects=follow_redirects - 1,
                                  secure=(protocol == 'https'),
                                  keyfile=keyfile,
                                  certfile=certfile)
      if resp.status != 200:
        raise ProtocolBufferReturnError(resp.status)
      if response is not None:
        response.ParseFromString(resp.read())
      return response
    finally:
      # Always release the connection, including on errors and redirects.
      conn.close()

  def sendSecureCommand(self, server, keyfile, certfile, url, response,
                        follow_redirects=1):
    """sendCommand() over HTTPS with the given client key/certificate."""
    return self.sendCommand(server, url, response,
                            follow_redirects=follow_redirects,
                            secure=1, keyfile=keyfile, certfile=certfile)

  def __str__(self, prefix="", printElemNumber=0):
    raise AbstractMethod

  def ToASCII(self):
    """Return the text form produced with _SYMBOLIC_FULL_ASCII."""
    return self._CToASCII(ProtocolMessage._SYMBOLIC_FULL_ASCII)

  def ToCompactASCII(self):
    """Return the text form produced with _NUMERIC_ASCII."""
    return self._CToASCII(ProtocolMessage._NUMERIC_ASCII)

  def ToShortASCII(self):
    """Return the text form produced with _SYMBOLIC_SHORT_ASCII."""
    return self._CToASCII(ProtocolMessage._SYMBOLIC_SHORT_ASCII)

  # Output formats passed to _CToASCII().
  _NUMERIC_ASCII = 0
  _SYMBOLIC_SHORT_ASCII = 1
  _SYMBOLIC_FULL_ASCII = 2

  def _CToASCII(self, output_format):
    raise AbstractMethod

  def ParseASCII(self, ascii_string):
    raise AbstractMethod

  def ParseASCIIIgnoreUnknown(self, ascii_string):
    raise AbstractMethod

  def Output(self, e):
    """Write the message to Encoder e after checking IsInitialized().

    Raises:
      ProtocolBufferEncodeError: if the message is not initialized.
    """
    dbg = []
    if not self.IsInitialized(dbg):
      raise ProtocolBufferEncodeError('\n\t'.join(dbg))
    self.OutputUnchecked(e)

  def OutputUnchecked(self, e):
    raise AbstractMethod

  def Parse(self, d):
    """Clear the message, then merge it from Decoder d."""
    self.Clear()
    self.Merge(d)

  def Merge(self, d):
    """Merge from Decoder d, then check IsInitialized().

    Raises:
      ProtocolBufferDecodeError: if the merged message is not initialized.
    """
    self.TryMerge(d)
    dbg = []
    if not self.IsInitialized(dbg):
      raise ProtocolBufferDecodeError('\n\t'.join(dbg))

  def TryMerge(self, d):
    raise AbstractMethod

  def CopyFrom(self, pb):
    """Replace this message's contents with a copy of pb's."""
    # Identity, not equality: copying onto itself would Clear() the source
    # before MergeFrom() reads it. A subclass __eq__ must not affect this.
    if pb is self:
      return
    self.Clear()
    self.MergeFrom(pb)

  def MergeFrom(self, pb):
    raise AbstractMethod

  def lengthVarInt32(self, n):
    """Encoded size of n as a varint; same as lengthVarInt64."""
    return self.lengthVarInt64(n)

  def lengthVarInt64(self, n):
    """Return how many bytes n occupies as a varint (7 bits per byte).

    Negative values always take 10 bytes.
    """
    if n < 0:
      return 10
    result = 0
    while True:
      result += 1
      n >>= 7
      if n == 0:
        break
    return result

  def lengthString(self, n):
    """Encoded size of an n-byte string with its varint length prefix."""
    return self.lengthVarInt32(n) + n

  def DebugFormat(self, value):
    return "%s" % value

  def DebugFormatInt32(self, value):
    """Decimal, or hex via DebugFormatFixed32 when |value| >= 2e9."""
    if value <= -2000000000 or value >= 2000000000:
      return self.DebugFormatFixed32(value)
    return "%d" % value

  def DebugFormatInt64(self, value):
    """Decimal, or hex via DebugFormatFixed64 when |value| >= 2e9."""
    if value <= -2000000000 or value >= 2000000000:
      return self.DebugFormatFixed64(value)
    return "%d" % value

  def DebugFormatString(self, value):
    """Double-quote value, escaping quotes, backslash, newline and
    non-printable characters (as 3-digit octal)."""
    def escape(c):
      o = ord(c)
      if o == 10: return r"\n"
      if o == 39: return r"\'"
      if o == 34: return r'\"'
      if o == 92: return r"\\"
      if o >= 127 or o < 32: return "\\%03o" % o
      return c
    return '"' + "".join([escape(c) for c in value]) + '"'

  def DebugFormatFloat(self, value):
    return "%ff" % value

  def DebugFormatFixed32(self, value):
    """Hex of value, with negatives shown as their 32-bit two's complement."""
    if value < 0:
      value += (1 << 32)
    return "0x%x" % value

  def DebugFormatFixed64(self, value):
    """Hex of value, with negatives shown as their 64-bit two's complement."""
    if value < 0:
      value += (1 << 64)
    return "0x%x" % value

  def DebugFormatBool(self, value):
    if value:
      return "true"
    else:
      return "false"
|
232 |
|
class Encoder:
  """Pure-Python writer for the protocol buffer wire format.

  Encoded bytes accumulate in an array.array('B') returned by buffer().
  Fixed-width integers and floating-point values are written little-endian.
  Out-of-range values raise ProtocolBufferEncodeError.
  """

  # Wire types (Decoder.skipData reads them from the low 3 bits of a tag).
  NUMERIC = 0
  DOUBLE = 1
  STRING = 2
  STARTGROUP = 3
  ENDGROUP = 4
  FLOAT = 5
  MAX_TYPE = 6

  def __init__(self):
    self.buf = array.array('B')

  def buffer(self):
    """Return the underlying byte array (not a copy)."""
    return self.buf

  def put8(self, v):
    """Append v (0 <= v < 2**8) as one byte."""
    if v < 0 or v >= (1 << 8):
      raise ProtocolBufferEncodeError("u8 too big")
    self.buf.append(v & 255)

  def put16(self, v):
    """Append v (0 <= v < 2**16) as 2 little-endian bytes."""
    if v < 0 or v >= (1 << 16):
      raise ProtocolBufferEncodeError("u16 too big")
    for shift in (0, 8):
      self.buf.append((v >> shift) & 255)

  def put32(self, v):
    """Append v (0 <= v < 2**32) as 4 little-endian bytes."""
    if v < 0 or v >= (1 << 32):
      raise ProtocolBufferEncodeError("u32 too big")
    for shift in (0, 8, 16, 24):
      self.buf.append((v >> shift) & 255)

  def put64(self, v):
    """Append v (0 <= v < 2**64) as 8 little-endian bytes."""
    if v < 0 or v >= (1 << 64):
      raise ProtocolBufferEncodeError("u64 too big")
    for shift in range(0, 64, 8):
      self.buf.append((v >> shift) & 255)

  def putVarInt32(self, v):
    """Append a signed 32-bit value as a varint (negatives take 10 bytes)."""
    if v >= (1 << 31) or v < -(1 << 31):
      raise ProtocolBufferEncodeError("int32 too big")
    self.putVarInt64(v)

  def putVarInt64(self, v):
    """Append a signed 64-bit value as a varint, two's complement for
    negatives."""
    if v >= (1 << 63) or v < -(1 << 63):
      raise ProtocolBufferEncodeError("int64 too big")
    if v < 0:
      v += (1 << 64)
    self.putVarUint64(v)

  def putVarUint64(self, v):
    """Append an unsigned 64-bit value as a varint: 7 bits per byte, low
    group first, high bit set on every byte except the last."""
    if v < 0 or v >= (1 << 64):
      raise ProtocolBufferEncodeError("uint64 too big")
    while True:
      bits = v & 127
      v >>= 7
      if v != 0:
        bits |= 128
      self.buf.append(bits)
      if v == 0:
        break

  def putFloat(self, v):
    """Append v as a 4-byte IEEE float."""
    # Explicit '<': the wire format is little-endian regardless of the host;
    # native "f" would emit big-endian bytes on big-endian machines.
    self.buf.extend(array.array('B', struct.pack("<f", v)))

  def putDouble(self, v):
    """Append v as an 8-byte IEEE double (little-endian, see putFloat)."""
    self.buf.extend(array.array('B', struct.pack("<d", v)))

  def putBoolean(self, v):
    """Append 1 for a true value, 0 otherwise."""
    if v:
      self.buf.append(1)
    else:
      self.buf.append(0)

  def putPrefixedString(self, v):
    """Append the byte string v preceded by its varint length."""
    # Convert first so a bad argument leaves the buffer untouched instead of
    # leaving a dangling length prefix behind.
    data = array.array('B', v)
    self.putVarInt32(len(data))
    self.buf.extend(data)

  def putRawString(self, v):
    """Append the byte string v with no length prefix."""
    self.buf.extend(array.array('B', v))
|
339 |
|
340 |
|
class Decoder:
  """Pure-Python reader for the protocol buffer wire format.

  Reads byte values from buf (indexable and sliceable; MergeFromString
  passes an array.array('B')) starting at index idx and stopping before
  limit. Fixed-width integers and floating-point values are read
  little-endian. Bad input raises ProtocolBufferDecodeError with the message
  "truncated" or "corrupted".
  """

  def __init__(self, buf, idx, limit):
    self.buf = buf
    self.idx = idx
    self.limit = limit

  def avail(self):
    """Return the number of unread bytes before limit."""
    return self.limit - self.idx

  def buffer(self):
    """Return the underlying buffer."""
    return self.buf

  def pos(self):
    """Return the current read index."""
    return self.idx

  def _ToBytes(self, a):
    """Return the raw bytes of array a (tostring() is gone in Python 3.9+)."""
    if hasattr(a, 'tobytes'):
      return a.tobytes()
    return a.tostring()

  def skip(self, n):
    """Advance the read position by n bytes.

    Raises:
      ProtocolBufferDecodeError: "corrupted" if n is negative (e.g. a length
        decoded from corrupt input, which would otherwise move the position
        backwards), "truncated" if fewer than n bytes remain.
    """
    if n < 0:
      raise ProtocolBufferDecodeError("corrupted")
    if self.idx + n > self.limit:
      raise ProtocolBufferDecodeError("truncated")
    self.idx += n

  def skipData(self, tag):
    """Skip the value belonging to the already-read field tag.

    The low three bits of tag select the wire type (Encoder constants).
    """
    wire_type = tag & 7
    if wire_type == Encoder.NUMERIC:
      self.getVarInt64()
    elif wire_type == Encoder.DOUBLE:
      self.skip(8)
    elif wire_type == Encoder.STRING:
      self.skip(self.getVarInt32())
    elif wire_type == Encoder.STARTGROUP:
      # Skip nested fields until an ENDGROUP tag appears.
      while True:
        inner = self.getVarInt32()
        if (inner & 7) == Encoder.ENDGROUP:
          break
        self.skipData(inner)
      # The end tag must equal the start tag apart from the wire type.
      if (inner - Encoder.ENDGROUP) != (tag - Encoder.STARTGROUP):
        raise ProtocolBufferDecodeError("corrupted")
    elif wire_type == Encoder.ENDGROUP:
      # ENDGROUP without a matching STARTGROUP.
      raise ProtocolBufferDecodeError("corrupted")
    elif wire_type == Encoder.FLOAT:
      self.skip(4)
    else:
      raise ProtocolBufferDecodeError("corrupted")

  def get8(self):
    """Read one byte."""
    if self.idx >= self.limit:
      raise ProtocolBufferDecodeError("truncated")
    c = self.buf[self.idx]
    self.idx += 1
    return c

  def get16(self):
    """Read a 2-byte little-endian unsigned integer."""
    if self.idx + 2 > self.limit:
      raise ProtocolBufferDecodeError("truncated")
    b, i = self.buf, self.idx
    self.idx += 2
    return b[i] | (b[i + 1] << 8)

  def get32(self):
    """Read a 4-byte little-endian unsigned integer."""
    if self.idx + 4 > self.limit:
      raise ProtocolBufferDecodeError("truncated")
    b, i = self.buf, self.idx
    self.idx += 4
    return b[i] | (b[i + 1] << 8) | (b[i + 2] << 16) | (b[i + 3] << 24)

  def get64(self):
    """Read an 8-byte little-endian unsigned integer."""
    if self.idx + 8 > self.limit:
      raise ProtocolBufferDecodeError("truncated")
    b, i = self.buf, self.idx
    self.idx += 8
    result = 0
    for k in range(8):
      result |= b[i + k] << (8 * k)
    return result

  def getVarInt32(self):
    """Read a varint that must fit in a signed 32-bit integer."""
    v = self.getVarInt64()
    if v >= (1 << 31) or v < -(1 << 31):
      raise ProtocolBufferDecodeError("corrupted")
    return v

  def getVarInt64(self):
    """Read a varint and reinterpret it as a signed 64-bit integer."""
    result = self.getVarUint64()
    if result >= (1 << 63):
      result -= (1 << 64)
    return result

  def getVarUint64(self):
    """Read an unsigned varint of at most 64 bits (7 bits per byte, low
    group first; a clear high bit ends the value)."""
    result = 0
    shift = 0
    while True:
      if shift >= 64:
        raise ProtocolBufferDecodeError("corrupted")
      b = self.get8()
      result |= (b & 127) << shift
      shift += 7
      if (b & 128) == 0:
        if result >= (1 << 64):
          raise ProtocolBufferDecodeError("corrupted")
        return result

  def getFloat(self):
    """Read a 4-byte IEEE float."""
    if self.idx + 4 > self.limit:
      raise ProtocolBufferDecodeError("truncated")
    a = self.buf[self.idx:self.idx + 4]
    self.idx += 4
    # Explicit '<' so the host byte order does not matter.
    return struct.unpack("<f", a)[0]

  def getDouble(self):
    """Read an 8-byte IEEE double."""
    if self.idx + 8 > self.limit:
      raise ProtocolBufferDecodeError("truncated")
    a = self.buf[self.idx:self.idx + 8]
    self.idx += 8
    return struct.unpack("<d", a)[0]

  def getBoolean(self):
    """Read a byte that must be 0 or 1 and return it."""
    b = self.get8()
    if b != 0 and b != 1:
      raise ProtocolBufferDecodeError("corrupted")
    return b

  def getPrefixedString(self):
    """Read a varint length followed by that many bytes; return the bytes."""
    length = self.getVarInt32()
    # Corrupt input can decode to a negative length; accepting it would move
    # idx backwards so the same bytes could be re-read indefinitely.
    if length < 0:
      raise ProtocolBufferDecodeError("corrupted")
    if self.idx + length > self.limit:
      raise ProtocolBufferDecodeError("truncated")
    r = self.buf[self.idx:self.idx + length]
    self.idx += length
    return self._ToBytes(r)

  def getRawString(self):
    """Return every remaining byte up to limit and consume them."""
    r = self.buf[self.idx:self.limit]
    self.idx = self.limit
    return self._ToBytes(r)
|
477 |
|
478 |
|
class ProtocolBufferDecodeError(Exception):
  """Raised when input cannot be decoded ("truncated", "corrupted") or the
  decoded message is not initialized."""


class ProtocolBufferEncodeError(Exception):
  """Raised when a value is out of range for its encoding or the message to
  encode is not initialized."""


class ProtocolBufferReturnError(Exception):
  """Raised by sendCommand with the HTTP status when it is not 200."""