thirdparty/google_appengine/google/net/proto/message_set.py
author Pawel Solyga <Pawel.Solyga@gmail.com>
Sun, 06 Sep 2009 23:31:53 +0200
changeset 2864 2e0b0af889be
permissions -rwxr-xr-x
Update Google App Engine from 1.2.3 to 1.2.5 in thirdparty folder.

#!/usr/bin/env python
#
# Copyright 2007 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""This module contains the MessageSet class, which is a special kind of
protocol message which can contain other protocol messages without knowing
their types.  See the class's doc string for more information."""


from google.net.proto import ProtocolBuffer
import logging

# Wire-format tags used by the MessageSet item encoding, each computed as
# (field_number << 3) | wire_type.  Every item is written as a group
# (field 1) containing a varint type id (field 2) and a length-prefixed
# serialized message (field 3).
TAG_BEGIN_ITEM_GROUP = (1 << 3) | 3  # 11: field 1, start-group
TAG_END_ITEM_GROUP   = (1 << 3) | 4  # 12: field 1, end-group
TAG_TYPE_ID          = (2 << 3) | 0  # 16: field 2, varint
TAG_MESSAGE          = (3 << 3) | 2  # 26: field 3, length-delimited

class Item:
  """One entry of a MessageSet: a payload plus the class it was parsed as.

  While message_class is None, `message` holds the still-serialized payload
  as read off the wire.  Once Parse() succeeds (or a parsed message is given
  to the constructor), `message` is an instance of message_class.  Parsing
  happens lazily and mutates the item in place.
  """

  def __init__(self, message, message_class=None):
    # message: raw serialized payload when message_class is None, otherwise
    # an instance of message_class.
    self.message = message
    self.message_class = message_class

  def SetToDefaultInstance(self, message_class):
    """Discard the current payload and replace it with message_class()."""
    self.message = message_class()
    self.message_class = message_class

  def Parse(self, message_class):
    """Make sure the payload is parsed, decoding it as message_class if needed.

    Returns:
      1 if the item is (now) parsed, 0 if decoding raised
      ProtocolBufferDecodeError (the item is then left unparsed).  An item
      that is already parsed returns 1 without comparing message_class.
    """
    if self.message_class is not None:
      return 1

    try:
      parsed = message_class(self.message)
    except ProtocolBuffer.ProtocolBufferDecodeError:
      # logging.warning replaces the deprecated logging.warn alias (removed
      # in Python 3.13); the formatted text is unchanged.
      logging.warning("Parse error in message inside MessageSet.  Tried "
                      "to parse as: %s", message_class.__name__)
      return 0
    self.message = parsed
    self.message_class = message_class
    return 1

  def MergeFrom(self, other):
    """Merge other's payload into this item.

    If either side is parsed, the merge is done on parsed messages, and
    `other` may be parsed in place as a side effect.  If this item is parsed
    but `other` fails to parse as the same class, other's payload is
    silently dropped.  If this item is unparsed and fails to parse as
    other's class, its own payload is discarded in favour of a default
    instance before merging.  If neither side is parsed, the serialized
    payloads are simply concatenated.
    """
    if self.message_class is not None:
      if other.Parse(self.message_class):
        self.message.MergeFrom(other.message)

    elif other.message_class is not None:
      if not self.Parse(other.message_class):
        self.message = other.message_class()
        self.message_class = other.message_class
      self.message.MergeFrom(other.message)

    else:
      # Concatenating serialized protocol buffers merges them.
      self.message += other.message

  def Copy(self):
    """Return an independent copy of this item (deep-copied if parsed)."""
    if self.message_class is None:
      return Item(self.message)
    else:
      new_message = self.message_class()
      new_message.CopyFrom(self.message)
      return Item(new_message, self.message_class)

  def Equals(self, other):
    """Compare two items, parsing whichever side is unparsed if necessary.

    Returns 0 if the needed parse fails.  Two unparsed items are compared
    byte-for-byte.
    """
    if self.message_class is not None:
      if not other.Parse(self.message_class): return 0
      return self.message.Equals(other.message)

    elif other.message_class is not None:
      if not self.Parse(other.message_class): return 0
      return self.message.Equals(other.message)

    else:
      return self.message == other.message

  def IsInitialized(self, debug_strs=None):
    """Unparsed payloads are treated as initialized; parsed ones are asked."""
    if self.message_class is None:
      return 1
    else:
      return self.message.IsInitialized(debug_strs)

  def ByteSize(self, pb, type_id):
    """Return the encoded size of this item, excluding the group tags.

    Args:
      pb: object providing lengthString()/lengthVarInt64() size helpers
        (MessageSet passes itself).
      type_id: the item's type id, encoded as a varint.
    """
    if self.message_class is None:
      message_length = len(self.message)
    else:
      message_length = self.message.ByteSize()

    # +2 for the one-byte TAG_TYPE_ID and TAG_MESSAGE tags; the group
    # begin/end tags are counted by the caller.
    return pb.lengthString(message_length) + pb.lengthVarInt64(type_id) + 2

  def OutputUnchecked(self, out, type_id):
    """Write the type id and length-prefixed payload fields of the group."""
    out.putVarInt32(TAG_TYPE_ID)
    out.putVarUint64(type_id)
    out.putVarInt32(TAG_MESSAGE)
    if self.message_class is None:
      out.putPrefixedString(self.message)
    else:
      out.putVarInt32(self.message.ByteSize())
      self.message.OutputUnchecked(out)

  def Decode(decoder):
    """Read one item group body up to and including TAG_END_ITEM_GROUP.

    Unknown fields are skipped.  Returns (type_id, raw_message_bytes).

    Raises:
      ProtocolBuffer.ProtocolBufferDecodeError: on a zero tag, or if the
        group lacks a non-zero type id or a message field.
    """
    type_id = 0
    message = None
    while 1:
      tag = decoder.getVarInt32()
      if tag == TAG_END_ITEM_GROUP:
        break
      if tag == TAG_TYPE_ID:
        type_id = decoder.getVarUint64()
        continue
      if tag == TAG_MESSAGE:
        message = decoder.getPrefixedString()
        continue
      if tag == 0: raise ProtocolBuffer.ProtocolBufferDecodeError
      decoder.skipData(tag)

    if type_id == 0 or message is None:
      raise ProtocolBuffer.ProtocolBufferDecodeError
    return (type_id, message)
  Decode = staticmethod(Decode)


class MessageSet(ProtocolBuffer.ProtocolMessage):
  """A protocol message that holds at most one message per type id.

  Items are keyed by the MESSAGE_TYPE_ID attribute of the message class
  used to access them.  Items read from the wire are stored as raw
  serialized bytes and only parsed when accessed with a concrete class
  (get, mutable, has, `in`, []).  This lets a MessageSet carry, merge and
  re-serialize messages whose types it does not know: unparsed items are
  written back out unchanged.

  Example:
    ms = MessageSet()
    ms[FooMessage] = foo          # store a parsed message
    foo2 = ms.get(FooMessage)     # parse lazily on access
  """

  def __init__(self, contents=None):
    """Create an empty set, optionally merging in serialized `contents`."""
    self.items = dict()
    if contents is not None: self.MergeFromString(contents)


  def get(self, message_class):
    """Return the stored message_class instance, parsing it if necessary.

    If no item has this type id, or its payload fails to parse, a new
    default message_class() is returned.  That instance is NOT inserted into
    the set, so changes to it are lost; use mutable() to modify.
    """
    if message_class.MESSAGE_TYPE_ID not in self.items:
      return message_class()
    item = self.items[message_class.MESSAGE_TYPE_ID]
    if item.Parse(message_class):
      return item.message
    else:
      return message_class()

  def mutable(self, message_class):
    """Return the stored message_class instance for in-place modification.

    Inserts a default instance if absent, and replaces a payload that fails
    to parse with a default instance.
    """
    if message_class.MESSAGE_TYPE_ID not in self.items:
      message = message_class()
      self.items[message_class.MESSAGE_TYPE_ID] = Item(message, message_class)
      return message
    item = self.items[message_class.MESSAGE_TYPE_ID]
    if not item.Parse(message_class):
      item.SetToDefaultInstance(message_class)
    return item.message

  def has(self, message_class):
    """Return 1 if an item with this type id exists and parses, else 0."""
    if message_class.MESSAGE_TYPE_ID not in self.items:
      return 0
    item = self.items[message_class.MESSAGE_TYPE_ID]
    return item.Parse(message_class)

  def has_unparsed(self, message_class):
    """Return True if an item with this type id exists, without parsing it."""
    return message_class.MESSAGE_TYPE_ID in self.items

  def GetTypeIds(self):
    """Return a list snapshot of the type ids currently present.

    A list (rather than a live dict view on Python 3) keeps the Python 2
    behaviour and stays valid if items are removed while iterating it.
    """
    return list(self.items.keys())

  def NumMessages(self):
    """Return the number of items, parsed or not."""
    return len(self.items)

  def remove(self, message_class):
    """Remove the item for message_class's type id, if present."""
    if message_class.MESSAGE_TYPE_ID in self.items:
      del self.items[message_class.MESSAGE_TYPE_ID]


  def __getitem__(self, message_class):
    """Like get(), but raise KeyError if absent or unparseable."""
    if message_class.MESSAGE_TYPE_ID not in self.items:
      raise KeyError(message_class)
    item = self.items[message_class.MESSAGE_TYPE_ID]
    if item.Parse(message_class):
      return item.message
    else:
      raise KeyError(message_class)

  def __setitem__(self, message_class, message):
    """Store `message` (an instance of message_class), replacing any item."""
    self.items[message_class.MESSAGE_TYPE_ID] = Item(message, message_class)

  def __contains__(self, message_class):
    return self.has(message_class)

  def __delitem__(self, message_class):
    self.remove(message_class)

  def __len__(self):
    return len(self.items)


  def MergeFrom(self, other):
    """Merge every item of `other` into this set.

    Items with a new type id are copied; existing ones are merged with
    Item.MergeFrom, which may parse other's items in place.
    """
    assert other is not self

    for (type_id, item) in other.items.items():
      if type_id in self.items:
        self.items[type_id].MergeFrom(item)
      else:
        self.items[type_id] = item.Copy()

  def Equals(self, other):
    """Return 1 if both sets hold the same type ids with equal items."""
    if other is self: return 1
    if len(self.items) != len(other.items): return 0

    for (type_id, item) in other.items.items():
      if type_id not in self.items: return 0
      if not self.items[type_id].Equals(item): return 0

    return 1

  def __eq__(self, other):
    return ((other is not None)
        and (other.__class__ == self.__class__)
        and self.Equals(other))

  def __ne__(self, other):
    return not (self == other)

  def IsInitialized(self, debug_strs=None):
    """Return 1 if every item is initialized.

    Deliberately does not short-circuit, so every item's IsInitialized is
    called with debug_strs.
    """
    initialized = 1
    for item in self.items.values():
      if not item.IsInitialized(debug_strs):
        initialized = 0
    return initialized

  def ByteSize(self):
    """Return the serialized size of the whole set."""
    # 2 bytes per item for the one-byte group begin/end tags.
    n = 2 * len(self.items)
    for (type_id, item) in self.items.items():
      n += item.ByteSize(self, type_id)
    return n

  def Clear(self):
    """Remove all items."""
    self.items = dict()

  def OutputUnchecked(self, out):
    """Serialize each item as a group into `out`."""
    for (type_id, item) in self.items.items():
      out.putVarInt32(TAG_BEGIN_ITEM_GROUP)
      item.OutputUnchecked(out, type_id)
      out.putVarInt32(TAG_END_ITEM_GROUP)

  def TryMerge(self, decoder):
    """Decode items from `decoder` until it is exhausted, merging them in.

    Decoded items stay unparsed.  Fields other than item groups are skipped.

    Raises:
      ProtocolBuffer.ProtocolBufferDecodeError: on a zero tag or a
        malformed item group.
    """
    while decoder.avail() > 0:
      tag = decoder.getVarInt32()
      if tag == TAG_BEGIN_ITEM_GROUP:
        (type_id, message) = Item.Decode(decoder)
        if type_id in self.items:
          self.items[type_id].MergeFrom(Item(message))
        else:
          self.items[type_id] = Item(message)
        continue
      if (tag == 0): raise ProtocolBuffer.ProtocolBufferDecodeError
      decoder.skipData(tag)

  def __str__(self, prefix="", printElemNumber=0):
    """Return a debug rendering; unparsed items show only their byte size."""
    # Collect pieces and join once instead of repeated string concatenation.
    parts = []
    for (type_id, item) in self.items.items():
      if item.message_class is None:
        parts.append("%s[%d] <\n" % (prefix, type_id))
        parts.append("%s  (%d bytes)\n" % (prefix, len(item.message)))
      else:
        parts.append("%s[%s] <\n" % (prefix, item.message_class.__name__))
        parts.append(item.message.__str__(prefix + "  ", printElemNumber))
      parts.append("%s>\n" % prefix)
    return "".join(parts)

# Public API: `import *` exports only MessageSet; the Item helper class and
# the TAG_* constants are not exported.
__all__ = [
    'MessageSet',
]