|
1 #!/usr/bin/env python |
|
2 # |
|
3 # Copyright 2007 Google Inc. |
|
4 # |
|
5 # Licensed under the Apache License, Version 2.0 (the "License"); |
|
6 # you may not use this file except in compliance with the License. |
|
7 # You may obtain a copy of the License at |
|
8 # |
|
9 # http://www.apache.org/licenses/LICENSE-2.0 |
|
10 # |
|
11 # Unless required by applicable law or agreed to in writing, software |
|
12 # distributed under the License is distributed on an "AS IS" BASIS, |
|
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
14 # See the License for the specific language governing permissions and |
|
15 # limitations under the License. |
|
16 # |
|
17 |
|
18 """This module contains the MessageSet class, which is a special kind of |
|
19 protocol message which can contain other protocol messages without knowing |
|
20 their types. See the class's doc string for more information.""" |
|
21 |
|
22 |
|
23 from google.net.proto import ProtocolBuffer |
|
24 import logging |
|
25 |
|
# Wire-format tags used by the MessageSet encoding. A protocol buffer tag is
# (field_number << 3) | wire_type. Each MessageSet entry is a group with
# field number 1, containing a varint type id (field 2) and the
# length-delimited serialized message (field 3).
TAG_BEGIN_ITEM_GROUP = (1 << 3) | 3  # field 1, wire type 3 (start group) == 11
TAG_END_ITEM_GROUP = (1 << 3) | 4    # field 1, wire type 4 (end group) == 12
TAG_TYPE_ID = (2 << 3) | 0           # field 2, wire type 0 (varint) == 16
TAG_MESSAGE = (3 << 3) | 2           # field 3, wire type 2 (length-delimited) == 26
|
30 |
|
class Item:
    """One entry of a MessageSet: a payload plus, once known, its message class.

    An Item starts out holding the raw serialized bytes of its payload
    (``message_class`` is None). It is converted in place into a parsed
    message instance the first time a caller supplies the class via Parse().
    """

    def __init__(self, message, message_class=None):
        # message: raw serialized bytes while message_class is None,
        # otherwise an instance of message_class.
        self.message = message
        self.message_class = message_class

    def SetToDefaultInstance(self, message_class):
        """Replace the payload with a fresh, default instance of message_class."""
        self.message = message_class()
        self.message_class = message_class

    def Parse(self, message_class):
        """Parse the raw payload as message_class unless already parsed.

        Args:
          message_class: callable that builds a message from serialized bytes.

        Returns:
          1 if the item is (now) parsed; 0 if decoding raised
          ProtocolBufferDecodeError. On failure the raw bytes are left
          untouched and message_class stays None, so a later call can retry.
        """
        if self.message_class is not None:
            # NOTE(review): this does not check that the already-parsed class
            # matches message_class.
            return 1

        try:
            parsed = message_class(self.message)
        except ProtocolBuffer.ProtocolBufferDecodeError:
            # logging.warning, not the deprecated logging.warn alias; lazy
            # %-args keep the emitted text identical to the original message.
            logging.warning("Parse error in message inside MessageSet. Tried "
                            "to parse as: %s", message_class.__name__)
            return 0
        self.message = parsed
        self.message_class = message_class
        return 1

    def MergeFrom(self, other):
        """Merge another Item's payload into this one.

        If either side is parsed, the other side is parsed with the same class
        and the messages are merged with MergeFrom. If neither side is parsed,
        the raw encodings are concatenated; for protocol buffers,
        concatenating encodings merges them.
        """
        if self.message_class is not None:
            # NOTE(review): if other cannot be parsed as our class, its data is
            # silently dropped.
            if other.Parse(self.message_class):
                self.message.MergeFrom(other.message)

        elif other.message_class is not None:
            if not self.Parse(other.message_class):
                # Our own raw bytes are unparseable. Discard them and start
                # from an empty instance so other's data is still merged in.
                self.message = other.message_class()
                self.message_class = other.message_class
            self.message.MergeFrom(other.message)

        else:
            self.message += other.message

    def Copy(self):
        """Return an independent copy of this Item."""
        if self.message_class is None:
            # Raw payloads are treated as immutable byte strings, so sharing
            # them is safe.
            return Item(self.message)
        new_message = self.message_class()
        new_message.CopyFrom(self.message)
        return Item(new_message, self.message_class)

    def Equals(self, other):
        """Compare with another Item, parsing one side if the other is parsed.

        Returns 0 if the needed parse fails.
        """
        if self.message_class is not None:
            if not other.Parse(self.message_class):
                return 0
            return self.message.Equals(other.message)

        elif other.message_class is not None:
            if not self.Parse(other.message_class):
                return 0
            return self.message.Equals(other.message)

        else:
            # NOTE: two unparsed items are compared byte-for-byte on their
            # encodings.
            return self.message == other.message

    def IsInitialized(self, debug_strs=None):
        """Unparsed payloads are reported as initialized; parsed ones delegate."""
        if self.message_class is None:
            return 1
        return self.message.IsInitialized(debug_strs)

    def ByteSize(self, pb, type_id):
        """Return the encoded size of this item's fields for the given type_id.

        The trailing 2 counts the one-byte TAG_TYPE_ID and TAG_MESSAGE tags.
        The surrounding group begin/end tags are not included here; they are
        not written by OutputUnchecked either. Assumes pb.lengthString(n)
        gives n plus the size of its length prefix (TODO confirm).
        """
        if self.message_class is None:
            message_length = len(self.message)
        else:
            message_length = self.message.ByteSize()

        return pb.lengthString(message_length) + pb.lengthVarInt64(type_id) + 2

    def OutputUnchecked(self, out, type_id):
        """Write the type id field and the length-prefixed message field."""
        out.putVarInt32(TAG_TYPE_ID)
        out.putVarUint64(type_id)
        out.putVarInt32(TAG_MESSAGE)
        if self.message_class is None:
            out.putPrefixedString(self.message)
        else:
            out.putVarInt32(self.message.ByteSize())
            self.message.OutputUnchecked(out)

    @staticmethod
    def Decode(decoder):
        """Read one item group body, up to and including its end tag.

        Unknown tags are skipped with decoder.skipData.

        Returns:
          (type_id, raw_message_bytes).

        Raises:
          ProtocolBuffer.ProtocolBufferDecodeError: on a zero tag, or if the
            type id (0 is treated as missing) or the message field is absent.
        """
        type_id = 0
        message = None
        while 1:
            tag = decoder.getVarInt32()
            if tag == TAG_END_ITEM_GROUP:
                break
            if tag == TAG_TYPE_ID:
                type_id = decoder.getVarUint64()
                continue
            if tag == TAG_MESSAGE:
                message = decoder.getPrefixedString()
                continue
            if tag == 0:
                raise ProtocolBuffer.ProtocolBufferDecodeError
            decoder.skipData(tag)

        if type_id == 0 or message is None:
            raise ProtocolBuffer.ProtocolBufferDecodeError
        return (type_id, message)
|
141 |
|
142 |
|
class MessageSet(ProtocolBuffer.ProtocolMessage):
    """A protocol message that holds other protocol messages keyed by type id.

    Each contained message is identified by its class's MESSAGE_TYPE_ID.
    Entries read off the wire are kept as raw serialized bytes. They are
    parsed only when a caller asks for them with a concrete message class
    (get, mutable, has, ``ms[cls]``, ``cls in ms``). This lets a MessageSet
    carry, merge and re-serialize messages whose types it does not know.

    Example::

        ms = MessageSet(serialized_bytes)
        if MyMessage in ms:
            msg = ms[MyMessage]
        other = ms.mutable(OtherMessage)  # created if absent
    """

    def __init__(self, contents=None):
        """Create an empty set, optionally filled via MergeFromString(contents)."""
        # Maps MESSAGE_TYPE_ID -> Item.
        self.items = dict()
        if contents is not None:
            self.MergeFromString(contents)

    def get(self, message_class):
        """Return the contained message of type message_class.

        If it is absent or fails to parse, a new default instance is returned;
        that instance is not stored in the set.
        """
        item = self.items.get(message_class.MESSAGE_TYPE_ID)
        if item is not None and item.Parse(message_class):
            return item.message
        return message_class()

    def mutable(self, message_class):
        """Return the stored message of type message_class, creating it if needed.

        An existing entry that cannot be parsed as message_class is replaced by
        a default instance, discarding its unparseable bytes.
        """
        type_id = message_class.MESSAGE_TYPE_ID
        item = self.items.get(type_id)
        if item is None:
            message = message_class()
            self.items[type_id] = Item(message, message_class)
            return message
        if not item.Parse(message_class):
            item.SetToDefaultInstance(message_class)
        return item.message

    def has(self, message_class):
        """Return 1 if a message of this type is present and parses, else 0.

        The entry is parsed as a side effect.
        """
        item = self.items.get(message_class.MESSAGE_TYPE_ID)
        if item is None:
            return 0
        return item.Parse(message_class)

    def has_unparsed(self, message_class):
        """Return whether an entry with this type id exists, without parsing it."""
        return message_class.MESSAGE_TYPE_ID in self.items

    def GetTypeIds(self):
        """Return a list snapshot of the type ids currently present.

        This is a copy, not a live view, so callers may remove entries while
        iterating over the result.
        """
        return list(self.items.keys())

    def NumMessages(self):
        """Return the number of entries, parsed or not."""
        return len(self.items)

    def remove(self, message_class):
        """Remove the entry for message_class if present; no error if absent."""
        self.items.pop(message_class.MESSAGE_TYPE_ID, None)

    def __getitem__(self, message_class):
        """Return the parsed message; KeyError if absent or unparseable."""
        item = self.items.get(message_class.MESSAGE_TYPE_ID)
        if item is None or not item.Parse(message_class):
            raise KeyError(message_class)
        return item.message

    def __setitem__(self, message_class, message):
        """Store message, an instance of message_class, replacing any entry."""
        self.items[message_class.MESSAGE_TYPE_ID] = Item(message, message_class)

    def __contains__(self, message_class):
        return self.has(message_class)

    def __delitem__(self, message_class):
        self.remove(message_class)

    def __len__(self):
        return len(self.items)

    def MergeFrom(self, other):
        """Merge every entry of other into this set.

        Shared type ids are merged item-by-item; new ones are copied so the two
        sets do not share mutable state.
        """
        assert other is not self

        for type_id, item in other.items.items():
            if type_id in self.items:
                self.items[type_id].MergeFrom(item)
            else:
                self.items[type_id] = item.Copy()

    def Equals(self, other):
        """Return 1 if both sets hold equal entries for the same type ids, else 0.

        Comparing may parse entries on either side.
        """
        if other is self:
            return 1
        if len(self.items) != len(other.items):
            return 0

        for type_id, item in other.items.items():
            if type_id not in self.items:
                return 0
            if not self.items[type_id].Equals(item):
                return 0

        return 1

    def __eq__(self, other):
        return ((other is not None)
                and (other.__class__ == self.__class__)
                and self.Equals(other))

    def __ne__(self, other):
        return not (self == other)

    def IsInitialized(self, debug_strs=None):
        """Return 1 if every entry is initialized, else 0.

        Every entry is visited (no early exit), so each one gets the chance to
        report into debug_strs.
        """
        initialized = 1
        for item in self.items.values():
            if not item.IsInitialized(debug_strs):
                initialized = 0
        return initialized

    def ByteSize(self):
        """Return the encoded size of the set.

        Each entry costs 2 bytes for its one-byte group begin/end tags, plus
        the item's own fields.
        """
        n = 2 * len(self.items)
        for type_id, item in self.items.items():
            n += item.ByteSize(self, type_id)
        return n

    def Clear(self):
        """Drop all entries."""
        self.items = dict()

    def OutputUnchecked(self, out):
        """Write each entry as a group: begin tag, item fields, end tag."""
        for type_id, item in self.items.items():
            out.putVarInt32(TAG_BEGIN_ITEM_GROUP)
            item.OutputUnchecked(out, type_id)
            out.putVarInt32(TAG_END_ITEM_GROUP)

    def TryMerge(self, decoder):
        """Decode entries from decoder and merge them into this set.

        Entries are kept as raw bytes; entries for an existing type id are
        merged into it. Tags outside item groups are skipped.

        Raises:
          ProtocolBuffer.ProtocolBufferDecodeError: on a zero tag, or via
            Item.Decode on a malformed group.
        """
        while decoder.avail() > 0:
            tag = decoder.getVarInt32()
            if tag == TAG_BEGIN_ITEM_GROUP:
                (type_id, message) = Item.Decode(decoder)
                if type_id in self.items:
                    self.items[type_id].MergeFrom(Item(message))
                else:
                    self.items[type_id] = Item(message)
                continue
            if tag == 0:
                raise ProtocolBuffer.ProtocolBufferDecodeError
            decoder.skipData(tag)

    def __str__(self, prefix="", printElemNumber=0):
        """Return a human-readable dump.

        Unparsed entries show their type id and byte length; parsed entries
        show their class name and the nested message's own text.
        """
        # Collect pieces and join once instead of repeated string +=.
        parts = []
        for type_id, item in self.items.items():
            if item.message_class is None:
                parts.append("%s[%d] <\n" % (prefix, type_id))
                parts.append("%s (%d bytes)\n" % (prefix, len(item.message)))
            else:
                parts.append("%s[%s] <\n" % (prefix, item.message_class.__name__))
                parts.append(item.message.__str__(prefix + " ", printElemNumber))
            parts.append("%s>\n" % prefix)
        return "".join(parts)
|
290 |
|
# Public API: only MessageSet is exported by "from <module> import *";
# Item and the TAG_* constants are implementation details.
__all__ = [
    "MessageSet",
]