thirdparty/google_appengine/google/net/proto/message_set.py
changeset 2864 2e0b0af889be
equal deleted inserted replaced
2862:27971a13089f 2864:2e0b0af889be
       
     1 #!/usr/bin/env python
       
     2 #
       
     3 # Copyright 2007 Google Inc.
       
     4 #
       
     5 # Licensed under the Apache License, Version 2.0 (the "License");
       
     6 # you may not use this file except in compliance with the License.
       
     7 # You may obtain a copy of the License at
       
     8 #
       
     9 #     http://www.apache.org/licenses/LICENSE-2.0
       
    10 #
       
    11 # Unless required by applicable law or agreed to in writing, software
       
    12 # distributed under the License is distributed on an "AS IS" BASIS,
       
    13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
       
    14 # See the License for the specific language governing permissions and
       
    15 # limitations under the License.
       
    16 #
       
    17 
       
    18 """This module contains the MessageSet class, which is a special kind of
       
    19 protocol message which can contain other protocol messages without knowing
       
    20 their types.  See the class's doc string for more information."""
       
    21 
       
    22 
       
    23 from google.net.proto import ProtocolBuffer
       
    24 import logging
       
    25 
       
    26 TAG_BEGIN_ITEM_GROUP = 11
       
    27 TAG_END_ITEM_GROUP   = 12
       
    28 TAG_TYPE_ID          = 16
       
    29 TAG_MESSAGE          = 26
       
    30 
       
    31 class Item:
       
    32 
       
    33   def __init__(self, message, message_class=None):
       
    34     self.message = message
       
    35     self.message_class = message_class
       
    36 
       
    37   def SetToDefaultInstance(self, message_class):
       
    38     self.message = message_class()
       
    39     self.message_class = message_class
       
    40 
       
    41   def Parse(self, message_class):
       
    42 
       
    43     if self.message_class is not None:
       
    44       return 1
       
    45 
       
    46     try:
       
    47       self.message = message_class(self.message)
       
    48       self.message_class = message_class
       
    49       return 1
       
    50     except ProtocolBuffer.ProtocolBufferDecodeError:
       
    51       logging.warn("Parse error in message inside MessageSet.  Tried "
       
    52                    "to parse as: " + message_class.__name__)
       
    53       return 0
       
    54 
       
    55   def MergeFrom(self, other):
       
    56 
       
    57     if self.message_class is not None:
       
    58       if other.Parse(self.message_class):
       
    59         self.message.MergeFrom(other.message)
       
    60 
       
    61     elif other.message_class is not None:
       
    62       if not self.Parse(other.message_class):
       
    63         self.message = other.message_class()
       
    64         self.message_class = other.message_class
       
    65       self.message.MergeFrom(other.message)
       
    66 
       
    67     else:
       
    68       self.message += other.message
       
    69 
       
    70   def Copy(self):
       
    71 
       
    72     if self.message_class is None:
       
    73       return Item(self.message)
       
    74     else:
       
    75       new_message = self.message_class()
       
    76       new_message.CopyFrom(self.message)
       
    77       return Item(new_message, self.message_class)
       
    78 
       
    79   def Equals(self, other):
       
    80 
       
    81     if self.message_class is not None:
       
    82       if not other.Parse(self.message_class): return 0
       
    83       return self.message.Equals(other.message)
       
    84 
       
    85     elif other.message_class is not None:
       
    86       if not self.Parse(other.message_class): return 0
       
    87       return self.message.Equals(other.message)
       
    88 
       
    89     else:
       
    90       return self.message == other.message
       
    91 
       
    92   def IsInitialized(self, debug_strs=None):
       
    93 
       
    94     if self.message_class is None:
       
    95       return 1
       
    96     else:
       
    97       return self.message.IsInitialized(debug_strs)
       
    98 
       
    99   def ByteSize(self, pb, type_id):
       
   100 
       
   101     message_length = 0
       
   102     if self.message_class is None:
       
   103       message_length = len(self.message)
       
   104     else:
       
   105       message_length = self.message.ByteSize()
       
   106 
       
   107     return pb.lengthString(message_length) + pb.lengthVarInt64(type_id) + 2
       
   108 
       
   109   def OutputUnchecked(self, out, type_id):
       
   110 
       
   111     out.putVarInt32(TAG_TYPE_ID)
       
   112     out.putVarUint64(type_id)
       
   113     out.putVarInt32(TAG_MESSAGE)
       
   114     if self.message_class is None:
       
   115       out.putPrefixedString(self.message)
       
   116     else:
       
   117       out.putVarInt32(self.message.ByteSize())
       
   118       self.message.OutputUnchecked(out)
       
   119 
       
   120   def Decode(decoder):
       
   121 
       
   122     type_id = 0
       
   123     message = None
       
   124     while 1:
       
   125       tag = decoder.getVarInt32()
       
   126       if tag == TAG_END_ITEM_GROUP:
       
   127         break
       
   128       if tag == TAG_TYPE_ID:
       
   129         type_id = decoder.getVarUint64()
       
   130         continue
       
   131       if tag == TAG_MESSAGE:
       
   132         message = decoder.getPrefixedString()
       
   133         continue
       
   134       if tag == 0: raise ProtocolBuffer.ProtocolBufferDecodeError
       
   135       decoder.skipData(tag)
       
   136 
       
   137     if type_id == 0 or message is None:
       
   138       raise ProtocolBuffer.ProtocolBufferDecodeError
       
   139     return (type_id, message)
       
   140   Decode = staticmethod(Decode)
       
   141 
       
   142 
       
   143 class MessageSet(ProtocolBuffer.ProtocolMessage):
       
   144 
       
   145   def __init__(self, contents=None):
       
   146     self.items = dict()
       
   147     if contents is not None: self.MergeFromString(contents)
       
   148 
       
   149 
       
   150   def get(self, message_class):
       
   151 
       
   152     if message_class.MESSAGE_TYPE_ID not in self.items:
       
   153       return message_class()
       
   154     item = self.items[message_class.MESSAGE_TYPE_ID]
       
   155     if item.Parse(message_class):
       
   156       return item.message
       
   157     else:
       
   158       return message_class()
       
   159 
       
   160   def mutable(self, message_class):
       
   161 
       
   162     if message_class.MESSAGE_TYPE_ID not in self.items:
       
   163       message = message_class()
       
   164       self.items[message_class.MESSAGE_TYPE_ID] = Item(message, message_class)
       
   165       return message
       
   166     item = self.items[message_class.MESSAGE_TYPE_ID]
       
   167     if not item.Parse(message_class):
       
   168       item.SetToDefaultInstance(message_class)
       
   169     return item.message
       
   170 
       
   171   def has(self, message_class):
       
   172 
       
   173     if message_class.MESSAGE_TYPE_ID not in self.items:
       
   174       return 0
       
   175     item = self.items[message_class.MESSAGE_TYPE_ID]
       
   176     return item.Parse(message_class)
       
   177 
       
   178   def has_unparsed(self, message_class):
       
   179     return message_class.MESSAGE_TYPE_ID in self.items
       
   180 
       
   181   def GetTypeIds(self):
       
   182     return self.items.keys()
       
   183 
       
   184   def NumMessages(self):
       
   185     return len(self.items)
       
   186 
       
   187   def remove(self, message_class):
       
   188     if message_class.MESSAGE_TYPE_ID in self.items:
       
   189       del self.items[message_class.MESSAGE_TYPE_ID]
       
   190 
       
   191 
       
   192   def __getitem__(self, message_class):
       
   193     if message_class.MESSAGE_TYPE_ID not in self.items:
       
   194       raise KeyError(message_class)
       
   195     item = self.items[message_class.MESSAGE_TYPE_ID]
       
   196     if item.Parse(message_class):
       
   197       return item.message
       
   198     else:
       
   199       raise KeyError(message_class)
       
   200 
       
   201   def __setitem__(self, message_class, message):
       
   202     self.items[message_class.MESSAGE_TYPE_ID] = Item(message, message_class)
       
   203 
       
   204   def __contains__(self, message_class):
       
   205     return self.has(message_class)
       
   206 
       
   207   def __delitem__(self, message_class):
       
   208     self.remove(message_class)
       
   209 
       
   210   def __len__(self):
       
   211     return len(self.items)
       
   212 
       
   213 
       
   214   def MergeFrom(self, other):
       
   215 
       
   216     assert other is not self
       
   217 
       
   218     for (type_id, item) in other.items.items():
       
   219       if type_id in self.items:
       
   220         self.items[type_id].MergeFrom(item)
       
   221       else:
       
   222         self.items[type_id] = item.Copy()
       
   223 
       
   224   def Equals(self, other):
       
   225     if other is self: return 1
       
   226     if len(self.items) != len(other.items): return 0
       
   227 
       
   228     for (type_id, item) in other.items.items():
       
   229       if type_id not in self.items: return 0
       
   230       if not self.items[type_id].Equals(item): return 0
       
   231 
       
   232     return 1
       
   233 
       
   234   def __eq__(self, other):
       
   235     return ((other is not None)
       
   236         and (other.__class__ == self.__class__)
       
   237         and self.Equals(other))
       
   238 
       
   239   def __ne__(self, other):
       
   240     return not (self == other)
       
   241 
       
   242   def IsInitialized(self, debug_strs=None):
       
   243 
       
   244     initialized = 1
       
   245     for item in self.items.values():
       
   246       if not item.IsInitialized(debug_strs):
       
   247         initialized = 0
       
   248     return initialized
       
   249 
       
   250   def ByteSize(self):
       
   251     n = 2 * len(self.items)
       
   252     for (type_id, item) in self.items.items():
       
   253       n += item.ByteSize(self, type_id)
       
   254     return n
       
   255 
       
   256   def Clear(self):
       
   257     self.items = dict()
       
   258 
       
   259   def OutputUnchecked(self, out):
       
   260     for (type_id, item) in self.items.items():
       
   261       out.putVarInt32(TAG_BEGIN_ITEM_GROUP)
       
   262       item.OutputUnchecked(out, type_id)
       
   263       out.putVarInt32(TAG_END_ITEM_GROUP)
       
   264 
       
   265   def TryMerge(self, decoder):
       
   266     while decoder.avail() > 0:
       
   267       tag = decoder.getVarInt32()
       
   268       if tag == TAG_BEGIN_ITEM_GROUP:
       
   269         (type_id, message) = Item.Decode(decoder)
       
   270         if type_id in self.items:
       
   271           self.items[type_id].MergeFrom(Item(message))
       
   272         else:
       
   273           self.items[type_id] = Item(message)
       
   274         continue
       
   275       if (tag == 0): raise ProtocolBuffer.ProtocolBufferDecodeError
       
   276       decoder.skipData(tag)
       
   277 
       
   278   def __str__(self, prefix="", printElemNumber=0):
       
   279     text = ""
       
   280     for (type_id, item) in self.items.items():
       
   281       if item.message_class is None:
       
   282         text += "%s[%d] <\n" % (prefix, type_id)
       
   283         text += "%s  (%d bytes)\n" % (prefix, len(item.message))
       
   284         text += "%s>\n" % prefix
       
   285       else:
       
   286         text += "%s[%s] <\n" % (prefix, item.message_class.__name__)
       
   287         text += item.message.__str__(prefix + "  ", printElemNumber)
       
   288         text += "%s>\n" % prefix
       
   289     return text
       
   290 
       
   291 __all__ = ['MessageSet']