eggs/mercurial-1.7.3-py2.6-linux-x86_64.egg/mercurial/sshserver.py
changeset 69 c6bca38c1cbf
equal deleted inserted replaced
68:5ff1fc726848 69:c6bca38c1cbf
       
     1 # sshserver.py - ssh protocol server support for mercurial
       
     2 #
       
     3 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
       
     4 # Copyright 2006 Vadim Gelfer <vadim.gelfer@gmail.com>
       
     5 #
       
     6 # This software may be used and distributed according to the terms of the
       
     7 # GNU General Public License version 2 or any later version.
       
     8 
       
     9 import util, hook, wireproto, changegroup
       
    10 import os, sys
       
    11 
       
    12 class sshserver(object):
       
    13     def __init__(self, ui, repo):
       
    14         self.ui = ui
       
    15         self.repo = repo
       
    16         self.lock = None
       
    17         self.fin = sys.stdin
       
    18         self.fout = sys.stdout
       
    19 
       
    20         hook.redirect(True)
       
    21         sys.stdout = sys.stderr
       
    22 
       
    23         # Prevent insertion/deletion of CRs
       
    24         util.set_binary(self.fin)
       
    25         util.set_binary(self.fout)
       
    26 
       
    27     def getargs(self, args):
       
    28         data = {}
       
    29         keys = args.split()
       
    30         count = len(keys)
       
    31         for n in xrange(len(keys)):
       
    32             argline = self.fin.readline()[:-1]
       
    33             arg, l = argline.split()
       
    34             val = self.fin.read(int(l))
       
    35             if arg not in keys:
       
    36                 raise util.Abort("unexpected parameter %r" % arg)
       
    37             if arg == '*':
       
    38                 star = {}
       
    39                 for n in xrange(int(l)):
       
    40                     arg, l = argline.split()
       
    41                     val = self.fin.read(int(l))
       
    42                     star[arg] = val
       
    43                 data['*'] = star
       
    44             else:
       
    45                 data[arg] = val
       
    46         return [data[k] for k in keys]
       
    47 
       
    48     def getarg(self, name):
       
    49         return self.getargs(name)[0]
       
    50 
       
    51     def getfile(self, fpout):
       
    52         self.sendresponse('')
       
    53         count = int(self.fin.readline())
       
    54         while count:
       
    55             fpout.write(self.fin.read(count))
       
    56             count = int(self.fin.readline())
       
    57 
       
    58     def redirect(self):
       
    59         pass
       
    60 
       
    61     def groupchunks(self, changegroup):
       
    62         while True:
       
    63             d = changegroup.read(4096)
       
    64             if not d:
       
    65                 break
       
    66             yield d
       
    67 
       
    68     def sendresponse(self, v):
       
    69         self.fout.write("%d\n" % len(v))
       
    70         self.fout.write(v)
       
    71         self.fout.flush()
       
    72 
       
    73     def sendstream(self, source):
       
    74         for chunk in source.gen:
       
    75             self.fout.write(chunk)
       
    76         self.fout.flush()
       
    77 
       
    78     def sendpushresponse(self, rsp):
       
    79         self.sendresponse('')
       
    80         self.sendresponse(str(rsp.res))
       
    81 
       
    82     def sendpusherror(self, rsp):
       
    83         self.sendresponse(rsp.res)
       
    84 
       
    85     def serve_forever(self):
       
    86         try:
       
    87             while self.serve_one():
       
    88                 pass
       
    89         finally:
       
    90             if self.lock is not None:
       
    91                 self.lock.release()
       
    92         sys.exit(0)
       
    93 
       
    94     handlers = {
       
    95         str: sendresponse,
       
    96         wireproto.streamres: sendstream,
       
    97         wireproto.pushres: sendpushresponse,
       
    98         wireproto.pusherr: sendpusherror,
       
    99     }
       
   100 
       
   101     def serve_one(self):
       
   102         cmd = self.fin.readline()[:-1]
       
   103         if cmd and cmd in wireproto.commands:
       
   104             rsp = wireproto.dispatch(self.repo, self, cmd)
       
   105             self.handlers[rsp.__class__](self, rsp)
       
   106         elif cmd:
       
   107             impl = getattr(self, 'do_' + cmd, None)
       
   108             if impl:
       
   109                 r = impl()
       
   110                 if r is not None:
       
   111                     self.sendresponse(r)
       
   112             else: self.sendresponse("")
       
   113         return cmd != ''
       
   114 
       
   115     def do_lock(self):
       
   116         '''DEPRECATED - allowing remote client to lock repo is not safe'''
       
   117 
       
   118         self.lock = self.repo.lock()
       
   119         return ""
       
   120 
       
   121     def do_unlock(self):
       
   122         '''DEPRECATED'''
       
   123 
       
   124         if self.lock:
       
   125             self.lock.release()
       
   126         self.lock = None
       
   127         return ""
       
   128 
       
   129     def do_addchangegroup(self):
       
   130         '''DEPRECATED'''
       
   131 
       
   132         if not self.lock:
       
   133             self.sendresponse("not locked")
       
   134             return
       
   135 
       
   136         self.sendresponse("")
       
   137         cg = changegroup.unbundle10(self.fin, "UN")
       
   138         r = self.repo.addchangegroup(cg, 'serve', self._client(),
       
   139                                      lock=self.lock)
       
   140         return str(r)
       
   141 
       
   142     def _client(self):
       
   143         client = os.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
       
   144         return 'remote:ssh:' + client