Mercurial > hg
comparison hgext/zeroconf/Zeroconf.py @ 7071:643c751e60b2
zeroconf: initial implementation
This is a basic, hopefully portable, zeroconf extension.
Enabling it will allow hg paths/pull/push/clone/etc. to automatically
discover services advertised as "_hg".
And naturally, running hg serve will advertise itself as a "_hg"
service as well as a "_http" service for use by browsers.
author | Matt Mackall <mpm@selenic.com> |
---|---|
date | Wed, 08 Oct 2008 19:58:35 -0500 |
parents | |
children | d812029cda85 |
comparison
equal
deleted
inserted
replaced
7070:2627ef59195d | 7071:643c751e60b2 |
---|---|
1 """ Multicast DNS Service Discovery for Python, v0.12 | |
2 Copyright (C) 2003, Paul Scott-Murphy | |
3 | |
4 This module provides a framework for the use of DNS Service Discovery | |
5 using IP multicast. It has been tested against the JRendezvous | |
6 implementation from <a href="http://strangeberry.com">StrangeBerry</a>, | |
7 and against the mDNSResponder from Mac OS X 10.3.8. | |
8 | |
9 This library is free software; you can redistribute it and/or | |
10 modify it under the terms of the GNU Lesser General Public | |
11 License as published by the Free Software Foundation; either | |
12 version 2.1 of the License, or (at your option) any later version. | |
13 | |
14 This library is distributed in the hope that it will be useful, | |
15 but WITHOUT ANY WARRANTY; without even the implied warranty of | |
16 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU | |
17 Lesser General Public License for more details. | |
18 | |
19 You should have received a copy of the GNU Lesser General Public | |
20 License along with this library; if not, write to the Free Software | |
21 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA | |
22 | |
23 """ | |
24 | |
25 """0.12 update - allow selection of binding interface | |
26 typo fix - Thanks A. M. Kuchlingi | |
27 removed all use of word 'Rendezvous' - this is an API change""" | |
28 | |
29 """0.11 update - correction to comments for addListener method | |
30 support for new record types seen from OS X | |
31 - IPv6 address | |
32 - hostinfo | |
33 ignore unknown DNS record types | |
34 fixes to name decoding | |
35 works alongside other processes using port 5353 (e.g. on Mac OS X) | |
36 tested against Mac OS X 10.3.2's mDNSResponder | |
37 corrections to removal of list entries for service browser""" | |
38 | |
39 """0.10 update - Jonathon Paisley contributed these corrections: | |
40 always multicast replies, even when query is unicast | |
41 correct a pointer encoding problem | |
42 can now write records in any order | |
43 traceback shown on failure | |
44 better TXT record parsing | |
45 server is now separate from name | |
46 can cancel a service browser | |
47 | |
48 modified some unit tests to accommodate these changes""" | |
49 | |
50 """0.09 update - remove all records on service unregistration | |
51 fix DOS security problem with readName""" | |
52 | |
53 """0.08 update - changed licensing to LGPL""" | |
54 | |
55 """0.07 update - faster shutdown on engine | |
56 pointer encoding of outgoing names | |
57 ServiceBrowser now works | |
58 new unit tests""" | |
59 | |
60 """0.06 update - small improvements with unit tests | |
61 added defined exception types | |
62 new style objects | |
63 fixed hostname/interface problem | |
64 fixed socket timeout problem | |
65 fixed addServiceListener() typo bug | |
66 using select() for socket reads | |
67 tested on Debian unstable with Python 2.2.2""" | |
68 | |
69 """0.05 update - ensure case insensitivty on domain names | |
70 support for unicast DNS queries""" | |
71 | |
72 """0.04 update - added some unit tests | |
73 added __ne__ adjuncts where required | |
74 ensure names end in '.local.' | |
75 timeout on receiving socket for clean shutdown""" | |
76 | |
77 __author__ = "Paul Scott-Murphy" | |
78 __email__ = "paul at scott dash murphy dot com" | |
79 __version__ = "0.12" | |
80 | |
81 import string | |
82 import time | |
83 import struct | |
84 import socket | |
85 import threading | |
86 import select | |
87 import traceback | |
88 | |
__all__ = ['Zeroconf', 'ServiceInfo', 'ServiceBrowser']

# Module-wide shutdown flag. The worker threads in this module poll it
# and return once it holds a true value.
_GLOBAL_DONE = 0
94 | |
# Some timing constants

_UNREGISTER_TIME = 125
_CHECK_TIME = 175
_REGISTER_TIME = 225
_LISTENER_TIME = 200
_BROWSER_TIME = 500

# Some DNS constants

_MDNS_ADDR = '224.0.0.251'
_MDNS_PORT = 5353
_DNS_PORT = 53
_DNS_TTL = 60 * 60  # one hour default TTL

_MAX_MSG_TYPICAL = 1460  # unused
_MAX_MSG_ABSOLUTE = 8972

# Header flag bits (RFC 1035 section 4.1.1 layout:
# QR | Opcode(4) | AA | TC | RD | RA | Z | AD | CD | RCODE(4))

_FLAGS_QR_MASK = 0x8000      # query response mask
_FLAGS_QR_QUERY = 0x0000     # query
_FLAGS_QR_RESPONSE = 0x8000  # response

_FLAGS_AA = 0x0400  # Authoritative answer
_FLAGS_TC = 0x0200  # Truncated
_FLAGS_RD = 0x0100  # Recursion desired
_FLAGS_RA = 0x0080  # Recursion available (0x8000 is the QR bit, not RA)

_FLAGS_Z = 0x0040   # Zero
_FLAGS_AD = 0x0020  # Authentic data
_FLAGS_CD = 0x0010  # Checking disabled

_CLASS_IN = 1
_CLASS_CS = 2
_CLASS_CH = 3
_CLASS_HS = 4
_CLASS_NONE = 254
_CLASS_ANY = 255
_CLASS_MASK = 0x7FFF    # class bits proper
_CLASS_UNIQUE = 0x8000  # top bit of the class field, split off by DNSEntry

_TYPE_A = 1
_TYPE_NS = 2
_TYPE_MD = 3
_TYPE_MF = 4
_TYPE_CNAME = 5
_TYPE_SOA = 6
_TYPE_MB = 7
_TYPE_MG = 8
_TYPE_MR = 9
_TYPE_NULL = 10
_TYPE_WKS = 11
_TYPE_PTR = 12
_TYPE_HINFO = 13
_TYPE_MINFO = 14
_TYPE_MX = 15
_TYPE_TXT = 16
_TYPE_AAAA = 28
_TYPE_SRV = 33
_TYPE_ANY = 255

# Mapping constants to names (used for string representations)

_CLASSES = {_CLASS_IN: "in",
            _CLASS_CS: "cs",
            _CLASS_CH: "ch",
            _CLASS_HS: "hs",
            _CLASS_NONE: "none",
            _CLASS_ANY: "any"}

_TYPES = {_TYPE_A: "a",
          _TYPE_NS: "ns",
          _TYPE_MD: "md",
          _TYPE_MF: "mf",
          _TYPE_CNAME: "cname",
          _TYPE_SOA: "soa",
          _TYPE_MB: "mb",
          _TYPE_MG: "mg",
          _TYPE_MR: "mr",
          _TYPE_NULL: "null",
          _TYPE_WKS: "wks",
          _TYPE_PTR: "ptr",
          _TYPE_HINFO: "hinfo",
          _TYPE_MINFO: "minfo",
          _TYPE_MX: "mx",
          _TYPE_TXT: "txt",
          _TYPE_AAAA: "quada",
          _TYPE_SRV: "srv",
          _TYPE_ANY: "any"}
183 | |
184 # utility functions | |
185 | |
def currentTimeMillis():
    """Return the current wall-clock time in milliseconds since the epoch,
    as a float (time.time() scaled by 1000)."""
    seconds = time.time()
    return seconds * 1000
189 | |
190 # Exceptions | |
191 | |
class NonLocalNameException(Exception):
    """Raised when a name that must end in '.local.' does not."""


class NonUniqueNameException(Exception):
    """Indicates that a name which must be unique is not."""


class NamePartTooLongException(Exception):
    """Raised when one label of a name is too long to encode."""


class AbstractMethodException(Exception):
    """Raised by base-class methods that subclasses are meant to override."""


class BadTypeInNameException(Exception):
    """Indicates an unusable type component in a DNS name."""
206 | |
207 # implementation classes | |
208 | |
class DNSEntry(object):
    """A DNS entry: a name together with a record type and class.

    The class value passed in may carry the _CLASS_UNIQUE bit; it is split
    off into ``self.unique`` and ``self.clazz`` keeps only the bits in
    _CLASS_MASK. ``self.key`` is the lower-cased name, which DNSCache uses
    as its lookup key."""

    def __init__(self, name, type, clazz):
        # str.lower() behaves like the old string.lower() but does not
        # depend on the Python-2-only function in the string module.
        self.key = name.lower()
        self.name = name
        self.type = type
        self.clazz = clazz & _CLASS_MASK
        self.unique = (clazz & _CLASS_UNIQUE) != 0

    def __eq__(self, other):
        """Equality test on name, type, and class"""
        if isinstance(other, DNSEntry):
            return self.name == other.name and self.type == other.type and self.clazz == other.clazz
        return 0

    def __ne__(self, other):
        """Non-equality test"""
        return not self.__eq__(other)

    def getClazz(self, clazz):
        """Return the mnemonic for a class value, or '?(<value>)' if the
        value is not in _CLASSES."""
        return _CLASSES.get(clazz, "?(%s)" % (clazz))

    def getType(self, type):
        """Return the mnemonic for a type value, or '?(<value>)' if the
        value is not in _TYPES."""
        return _TYPES.get(type, "?(%s)" % (type))

    def toString(self, hdr, other):
        """Return '<hdr>[<type>,<class>[-unique],<name>[,<other>]]'.

        ``other`` is omitted from the output when it is None."""
        if self.unique:
            separator = "-unique,"
        else:
            separator = ","
        result = "%s[%s,%s%s" % (hdr, self.getType(self.type),
                                  self.getClazz(self.clazz), separator)
        result += self.name
        if other is not None:
            result += ",%s]" % (other)
        else:
            result += "]"
        return result
256 | |
class DNSQuestion(DNSEntry):
    """A question-section entry. Only names ending in '.local.' are
    accepted; any other name raises NonLocalNameException."""

    def __init__(self, name, type, clazz):
        if name.endswith(".local."):
            DNSEntry.__init__(self, name, type, clazz)
        else:
            raise NonLocalNameException

    def answeredBy(self, rec):
        """True when rec has this question's class and name, and either
        its type matches or this question asks for _TYPE_ANY."""
        return (self.clazz == rec.clazz
                and self.type in (rec.type, _TYPE_ANY)
                and self.name == rec.name)

    def __repr__(self):
        """Render as 'question[...]' via DNSEntry.toString."""
        return DNSEntry.toString(self, "question", None)
272 | |
273 | |
class DNSRecord(DNSEntry):
    """A DNS entry that also carries a TTL in seconds, plus the local time
    (currentTimeMillis) at which the record object was created."""

    def __init__(self, name, type, clazz, ttl):
        DNSEntry.__init__(self, name, type, clazz)
        self.ttl = ttl
        self.created = currentTimeMillis()

    def __eq__(self, other):
        """Records are equal when both are DNSRecords with the same name,
        type and class."""
        if not isinstance(other, DNSRecord):
            return 0
        return DNSEntry.__eq__(self, other)

    def suppressedBy(self, msg):
        """Return 1 if any answer in msg makes this record redundant
        (see suppressedByAnswer), else 0."""
        for candidate in msg.answers:
            if self.suppressedByAnswer(candidate):
                break
        else:
            return 0
        return 1

    def suppressedByAnswer(self, other):
        """Return 1 if other equals this record and its TTL exceeds half
        of this record's TTL, else 0."""
        if not (self == other and other.ttl > (self.ttl / 2)):
            return 0
        return 1

    def getExpirationTime(self, percent):
        """Return the millisecond timestamp at which ``percent`` percent
        of the TTL will have elapsed since creation."""
        elapsedMillis = percent * self.ttl * 10
        return self.created + elapsedMillis

    def getRemainingTTL(self, now):
        """Return the seconds left before full expiry, never below 0."""
        secondsLeft = (self.getExpirationTime(100) - now) / 1000
        return max(0, secondsLeft)

    def isExpired(self, now):
        """True once 100% of the TTL has elapsed at time ``now``."""
        deadline = self.getExpirationTime(100)
        return deadline <= now

    def isStale(self, now):
        """True once at least 50% of the TTL has elapsed at time ``now``."""
        halfway = self.getExpirationTime(50)
        return halfway <= now

    def resetTTL(self, other):
        """Copy the creation time and TTL from another record."""
        self.created = other.created
        self.ttl = other.ttl

    def write(self, out):
        """Abstract: subclasses serialize their record data into ``out``."""
        raise AbstractMethodException

    def toString(self, other):
        """Render as 'record[...]' with '<ttl>/<remaining>,<other>' appended."""
        remaining = self.getRemainingTTL(currentTimeMillis())
        extra = "%s/%s,%s" % (self.ttl, remaining, other)
        return DNSEntry.toString(self, "record", extra)
334 | |
class DNSAddress(DNSRecord):
    """A DNS address record. ``address`` holds the packed bytes as read
    from or written to the wire (4 bytes for A, 16 for AAAA in
    DNSIncoming)."""

    def __init__(self, name, type, clazz, ttl, address):
        DNSRecord.__init__(self, name, type, clazz, ttl)
        self.address = address

    def write(self, out):
        """Used in constructing an outgoing packet"""
        out.writeString(self.address, len(self.address))

    def __eq__(self, other):
        """Tests equality on address"""
        if isinstance(other, DNSAddress):
            return self.address == other.address
        return 0

    def __repr__(self):
        """Return the address in printable form.

        4-byte addresses use dotted-quad notation. 16-byte addresses use
        IPv6 notation when the platform provides socket.inet_ntop.
        Anything that cannot be converted falls back to the raw value."""
        try:
            if len(self.address) == 16 and hasattr(socket, 'inet_ntop'):
                return socket.inet_ntop(socket.AF_INET6, self.address)
            return socket.inet_ntoa(self.address)
        except (socket.error, TypeError, ValueError):
            # Wrong length or type for the conversion functions.
            return self.address
358 | |
class DNSHinfo(DNSRecord):
    """A DNS host information (HINFO) record: a CPU string and an OS string."""

    def __init__(self, name, type, clazz, ttl, cpu, os):
        DNSRecord.__init__(self, name, type, clazz, ttl)
        self.cpu = cpu
        self.os = os

    def write(self, out):
        """Used in constructing an outgoing packet.

        HINFO data is two character-strings, each a length byte followed
        by that many bytes. DNSIncoming reads them back with
        readCharacterString, so the length prefix must be written here
        too; without it the record cannot be parsed. writeByte cannot
        encode a length above 255."""
        for value in (self.cpu, self.os):
            out.writeByte(len(value))
            out.writeString(value, len(value))

    def __eq__(self, other):
        """Tests equality on cpu and os"""
        if isinstance(other, DNSHinfo):
            return self.cpu == other.cpu and self.os == other.os
        return 0

    def __repr__(self):
        """String representation: '<cpu> <os>'"""
        return self.cpu + " " + self.os
381 | |
class DNSPointer(DNSRecord):
    """A record whose data is another domain name (DNSIncoming builds
    these for PTR and CNAME types)."""

    def __init__(self, name, type, clazz, ttl, alias):
        DNSRecord.__init__(self, name, type, clazz, ttl)
        self.alias = alias  # the name this record points at

    def write(self, out):
        """Serialize the alias into ``out`` via writeName."""
        out.writeName(self.alias)

    def __eq__(self, other):
        """Pointers are equal when both are DNSPointers with the same alias."""
        if not isinstance(other, DNSPointer):
            return 0
        return self.alias == other.alias

    def __repr__(self):
        """Render via DNSRecord.toString with the alias as extra info."""
        target = self.alias
        return self.toString(target)
402 | |
class DNSText(DNSRecord):
    """A TXT record whose data is written out verbatim."""

    def __init__(self, name, type, clazz, ttl, text):
        DNSRecord.__init__(self, name, type, clazz, ttl)
        self.text = text

    def write(self, out):
        """Append the raw text bytes to ``out``."""
        size = len(self.text)
        out.writeString(self.text, size)

    def __eq__(self, other):
        """TXT records are equal when both are DNSTexts with the same text."""
        if not isinstance(other, DNSText):
            return 0
        return self.text == other.text

    def __repr__(self):
        """Render via toString. Text longer than 10 characters is cut to
        its first 7 characters followed by '...'."""
        shown = self.text
        if len(shown) > 10:
            shown = shown[:7] + "..."
        return self.toString(shown)
426 | |
class DNSService(DNSRecord):
    """A service (SRV) record: priority, weight, port and target server."""

    def __init__(self, name, type, clazz, ttl, priority, weight, port, server):
        DNSRecord.__init__(self, name, type, clazz, ttl)
        self.priority, self.weight = priority, weight
        self.port, self.server = port, server

    def write(self, out):
        """Write priority, weight and port as shorts, then the server name."""
        for number in (self.priority, self.weight, self.port):
            out.writeShort(number)
        out.writeName(self.server)

    def __eq__(self, other):
        """Services are equal when priority, weight, port and server match."""
        if not isinstance(other, DNSService):
            return 0
        return (self.priority == other.priority
                and self.weight == other.weight
                and self.port == other.port
                and self.server == other.server)

    def __repr__(self):
        """Render via toString with '<server>:<port>' as extra info."""
        endpoint = "%s:%s" % (self.server, self.port)
        return self.toString(endpoint)
453 | |
class DNSIncoming(object):
    """Object representation of an incoming DNS packet.

    The whole packet is parsed in the constructor. Questions go to
    self.questions; records from the answer, authority and additional
    sections all go to self.answers."""

    def __init__(self, data):
        """Constructor from string holding bytes of packet"""
        self.offset = 0
        self.data = data
        self.questions = []
        self.answers = []
        self.numQuestions = 0
        self.numAnswers = 0
        self.numAuthorities = 0
        self.numAdditionals = 0

        self.readHeader()
        self.readQuestions()
        self.readOthers()

    def _byteAt(self, offset):
        """Return the unsigned value of the byte at ``offset``.

        Indexing a Python 2 str yields a one-character string, while
        indexing Python 3 bytes yields an int; accept both."""
        value = self.data[offset]
        if isinstance(value, int):
            return value
        return ord(value)

    def readHeader(self):
        """Reads header portion of packet"""
        format = '!HHHHHH'
        length = struct.calcsize(format)
        info = struct.unpack(format, self.data[self.offset:self.offset+length])
        self.offset += length

        self.id = info[0]
        self.flags = info[1]
        self.numQuestions = info[2]
        self.numAnswers = info[3]
        self.numAuthorities = info[4]
        self.numAdditionals = info[5]

    def readQuestions(self):
        """Reads questions section of packet"""
        format = '!HH'
        length = struct.calcsize(format)
        for i in range(0, self.numQuestions):
            name = self.readName()
            info = struct.unpack(format, self.data[self.offset:self.offset+length])
            self.offset += length

            question = DNSQuestion(name, info[0], info[1])
            self.questions.append(question)

    def readInt(self):
        """Reads an unsigned 32-bit integer from the packet"""
        format = '!I'
        length = struct.calcsize(format)
        info = struct.unpack(format, self.data[self.offset:self.offset+length])
        self.offset += length
        return info[0]

    def readCharacterString(self):
        """Reads a length-prefixed character string from the packet"""
        length = self._byteAt(self.offset)
        self.offset += 1
        return self.readString(length)

    def readString(self, len):
        """Reads a string of a given length from the packet"""
        format = '!' + str(len) + 's'
        length = struct.calcsize(format)
        info = struct.unpack(format, self.data[self.offset:self.offset+length])
        self.offset += length
        return info[0]

    def readUnsignedShort(self):
        """Reads an unsigned short from the packet"""
        format = '!H'
        length = struct.calcsize(format)
        info = struct.unpack(format, self.data[self.offset:self.offset+length])
        self.offset += length
        return info[0]

    def readOthers(self):
        """Reads the answers, authorities and additionals section of the packet.

        Records of a type this parser does not understand are skipped
        using their RDLENGTH, so the records that follow still parse."""
        format = '!HHiH'
        length = struct.calcsize(format)
        n = self.numAnswers + self.numAuthorities + self.numAdditionals
        for i in range(0, n):
            domain = self.readName()
            # info = (type, class, ttl, rdlength)
            info = struct.unpack(format, self.data[self.offset:self.offset+length])
            self.offset += length

            rec = None
            if info[0] == _TYPE_A:
                rec = DNSAddress(domain, info[0], info[1], info[2], self.readString(4))
            elif info[0] == _TYPE_CNAME or info[0] == _TYPE_PTR:
                rec = DNSPointer(domain, info[0], info[1], info[2], self.readName())
            elif info[0] == _TYPE_TXT:
                rec = DNSText(domain, info[0], info[1], info[2], self.readString(info[3]))
            elif info[0] == _TYPE_SRV:
                rec = DNSService(domain, info[0], info[1], info[2], self.readUnsignedShort(), self.readUnsignedShort(), self.readUnsignedShort(), self.readName())
            elif info[0] == _TYPE_HINFO:
                rec = DNSHinfo(domain, info[0], info[1], info[2], self.readCharacterString(), self.readCharacterString())
            elif info[0] == _TYPE_AAAA:
                rec = DNSAddress(domain, info[0], info[1], info[2], self.readString(16))
            else:
                # Unknown type: step over its data so that the offset
                # lands on the next record instead of inside this one.
                self.offset += info[3]

            if rec is not None:
                self.answers.append(rec)

    def isQuery(self):
        """Returns true if this is a query"""
        return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_QUERY

    def isResponse(self):
        """Returns true if this is a response"""
        return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE

    def readUTF(self, offset, len):
        """Reads a UTF-8 string of a given length from the packet"""
        result = self.data[offset:offset+len].decode('utf-8')
        return result

    def readName(self):
        """Reads a (possibly compressed) domain name from the packet.

        Each label is decoded as UTF-8 and followed by '.'. A length byte
        with the top two bits set is a compression pointer. Every pointer
        must target an offset strictly before the place the current run
        of labels began, which rules out loops. Afterwards self.offset is
        just past the name where it started: past the first pointer if
        one was followed, otherwise past the terminating zero byte.

        Raises ValueError for a circular pointer or an unsupported label
        type."""
        result = ''
        off = self.offset
        resumeAt = -1  # offset just after the first pointer followed
        first = off    # pointers must target offsets below this

        while 1:
            length = self._byteAt(off)
            off += 1
            if length == 0:
                break
            t = length & 0xC0
            if t == 0x00:
                result = ''.join((result, self.readUTF(off, length) + '.'))
                off += length
            elif t == 0xC0:
                if resumeAt < 0:
                    resumeAt = off + 1
                off = ((length & 0x3F) << 8) | self._byteAt(off)
                if off >= first:
                    raise ValueError("Bad domain name (circular) at " + str(off))
                first = off
            else:
                raise ValueError("Bad domain name at " + str(off))

        if resumeAt >= 0:
            self.offset = resumeAt
        else:
            self.offset = off

        return result
610 | |
611 | |
class DNSOutgoing(object):
    """Object representation of an outgoing packet.

    Questions and records are queued with the add*() methods and are only
    serialized when packet() is called. ``self.data`` is the list of byte
    chunks written so far. While the packet is being built, ``self.size``
    is the offset in the finished packet at which the next byte will land;
    the 12-byte header is counted up front because it is inserted last.
    ``self.names`` maps every name already written to its offset so later
    occurrences can be compressed into pointers."""

    def __init__(self, flags, multicast = 1):
        self.finished = 0
        self.id = 0
        self.multicast = multicast
        self.flags = flags
        self.names = {}
        self.data = []
        self.size = 12

        self.questions = []
        self.answers = []
        self.authorities = []
        self.additionals = []

    def addQuestion(self, record):
        """Adds a question"""
        self.questions.append(record)

    def addAnswer(self, inp, record):
        """Adds an answer unless an answer in the incoming message ``inp``
        already suppresses it (see DNSRecord.suppressedBy)."""
        if not record.suppressedBy(inp):
            self.addAnswerAtTime(record, 0)

    def addAnswerAtTime(self, record, now):
        """Adds an answer if it does not expire by a certain time.

        A ``now`` of 0 means "no time check"; the record is then written
        with its full TTL."""
        if record is not None:
            if now == 0 or not record.isExpired(now):
                self.answers.append((record, now))

    def addAuthorativeAnswer(self, record):
        """Adds an authoritative answer"""
        self.authorities.append(record)

    def addAdditionalAnswer(self, record):
        """Adds an additional answer"""
        self.additionals.append(record)

    def writeByte(self, value):
        """Writes a single byte (0-255) to the packet"""
        self.data.append(struct.pack('!B', value))
        self.size += 1

    def insertShort(self, index, value):
        """Inserts an unsigned short in a certain position in the packet"""
        format = '!H'
        self.data.insert(index, struct.pack(format, value))
        self.size += 2

    def writeShort(self, value):
        """Writes an unsigned short to the packet"""
        format = '!H'
        self.data.append(struct.pack(format, value))
        self.size += 2

    def writeInt(self, value):
        """Writes an unsigned integer to the packet"""
        format = '!I'
        self.data.append(struct.pack(format, int(value)))
        self.size += 4

    def writeString(self, value, length):
        """Writes a string to the packet"""
        format = '!' + str(length) + 's'
        self.data.append(struct.pack(format, value))
        self.size += length

    def writeUTF(self, s):
        """Writes a length-prefixed UTF-8 label to the packet.

        Raises NamePartTooLongException when the encoded label exceeds 63
        bytes, the largest length whose top two bits are clear (the top
        bits mark compression pointers)."""
        utfstr = s.encode('utf-8')
        length = len(utfstr)
        if length > 63:
            raise NamePartTooLongException
        self.writeByte(length)
        self.writeString(utfstr, length)

    def writeName(self, name):
        """Writes a domain name to the packet, as a pointer if it was
        written before."""

        try:
            # Find existing instance of this name in packet
            #
            index = self.names[name]
        except KeyError:
            # No record of this name already, so write it
            # out as normal, recording the location of the name
            # for future pointers to it.
            #
            self.names[name] = self.size
            parts = name.split('.')
            if parts[-1] == '':
                parts = parts[:-1]
            for part in parts:
                self.writeUTF(part)
            self.writeByte(0)
            return

        # An index was found, so write a 14-bit pointer to it: the high
        # six bits go in the first byte (tagged 0xC0), the low eight bits
        # in the second. Masking keeps offsets above 255 encodable.
        #
        self.writeByte((index >> 8) | 0xC0)
        self.writeByte(index & 0xFF)

    def writeQuestion(self, question):
        """Writes a question to the packet"""
        self.writeName(question.name)
        self.writeShort(question.type)
        self.writeShort(question.clazz)

    def writeRecord(self, record, now):
        """Writes a record (answer, authoritative answer, additional) to
        the packet"""
        self.writeName(record.name)
        self.writeShort(record.type)
        if record.unique and self.multicast:
            self.writeShort(record.clazz | _CLASS_UNIQUE)
        else:
            self.writeShort(record.clazz)
        if now == 0:
            self.writeInt(record.ttl)
        else:
            self.writeInt(record.getRemainingTTL(now))
        index = len(self.data)
        # Adjust size for the short we will write before this record
        #
        self.size += 2
        record.write(self)
        self.size -= 2

        length = sum([len(chunk) for chunk in self.data[index:]])
        self.insertShort(index, length)  # Here is the short we adjusted for

    def packet(self):
        """Returns a string containing the packet's bytes

        No further parts should be added to the packet once this
        is done."""
        if not self.finished:
            self.finished = 1
            for question in self.questions:
                self.writeQuestion(question)
            for answer, time in self.answers:
                self.writeRecord(answer, time)
            for authority in self.authorities:
                self.writeRecord(authority, 0)
            for additional in self.additionals:
                self.writeRecord(additional, 0)

            self.insertShort(0, len(self.additionals))
            self.insertShort(0, len(self.authorities))
            self.insertShort(0, len(self.answers))
            self.insertShort(0, len(self.questions))
            self.insertShort(0, self.flags)
            if self.multicast:
                self.insertShort(0, 0)
            else:
                self.insertShort(0, self.id)
        # The header guarantees data is non-empty; slicing its first chunk
        # yields an empty string of the same type (str/bytes) to join with.
        return self.data[0][:0].join(self.data)
772 | |
773 | |
class DNSCache(object):
    """A cache of DNS entries, bucketed by each entry's ``key`` (the
    lower-cased name for DNSEntry objects)."""

    def __init__(self):
        # maps entry.key -> list of entries sharing that key
        self.cache = {}

    def add(self, entry):
        """Adds an entry to the bucket for its key"""
        self.cache.setdefault(entry.key, []).append(entry)

    def remove(self, entry):
        """Removes an entry; entries not present are silently ignored"""
        try:
            self.cache[entry.key].remove(entry)
        except (KeyError, ValueError):
            pass

    def get(self, entry):
        """Gets the cached entry equal to ``entry``. Will return None if
        there is no matching entry."""
        try:
            bucket = self.cache[entry.key]
            return bucket[bucket.index(entry)]
        except (KeyError, ValueError):
            return None

    def getByDetails(self, name, type, clazz):
        """Gets an entry by details. Will return None if there is
        no matching entry."""
        entry = DNSEntry(name, type, clazz)
        return self.get(entry)

    def entriesWithName(self, name):
        """Returns the list of entries whose key matches the name, or an
        empty list if there are none."""
        return self.cache.get(name, [])

    def entries(self):
        """Returns a new list of all entries.

        The result is always a fresh list, never one of the internal
        buckets, so callers (e.g. the Reaper) may remove entries from the
        cache while iterating over it."""
        result = []
        for bucket in list(self.cache.values()):
            result.extend(bucket)
        return result
825 | |
826 | |
class Engine(threading.Thread):
    """An engine wraps read access to sockets, allowing objects that
    need to receive data from sockets to be called back when the
    sockets are ready.

    A reader needs a handle_read() method, which is called when the socket
    it is interested in is ready for reading. The thread starts itself in
    the constructor and loops until the module-level _GLOBAL_DONE flag is
    true.

    Writers are not implemented here, because we only send short
    packets.
    """

    def __init__(self, zeroconf):
        threading.Thread.__init__(self)
        self.zeroconf = zeroconf
        self.readers = {}  # maps socket to reader
        self.timeout = 5   # seconds, for both select() and the idle wait
        self.condition = threading.Condition()
        self.start()

    def run(self):
        while not globals()['_GLOBAL_DONE']:
            rs = self.getReaders()
            if not rs:
                # No sockets to manage, but we wait for the timeout
                # or addition of a socket
                #
                self.condition.acquire()
                try:
                    self.condition.wait(self.timeout)
                finally:
                    self.condition.release()
            else:
                try:
                    rr, wr, er = select.select(rs, [], [], self.timeout)
                    for sock in rr:
                        reader = self.readers.get(sock)
                        if reader is None:
                            # removed by delReader since getReaders()
                            continue
                        try:
                            reader.handle_read()
                        except Exception:
                            traceback.print_exc()
                except Exception:
                    # Best effort: select() can fail while sockets are
                    # being closed; the loop re-checks _GLOBAL_DONE.
                    pass

    def getReaders(self):
        """Return a snapshot list of the sockets currently registered."""
        self.condition.acquire()
        try:
            return list(self.readers.keys())
        finally:
            self.condition.release()

    def addReader(self, reader, socket):
        """Register ``reader`` for ``socket`` and wake the run loop."""
        self.condition.acquire()
        try:
            self.readers[socket] = reader
            self.condition.notify()
        finally:
            self.condition.release()

    def delReader(self, socket):
        """Unregister ``socket`` and wake the run loop.

        Raises KeyError if the socket is not registered; the condition
        lock is released either way."""
        self.condition.acquire()
        try:
            del self.readers[socket]
            self.condition.notify()
        finally:
            self.condition.release()

    def notify(self):
        """Wake the run loop if it is waiting for sockets."""
        self.condition.acquire()
        try:
            self.condition.notify()
        finally:
            self.condition.release()
891 | |
class Listener(object):
    """Reads DNS messages arriving on the zeroconf socket and hands them
    to the owning Zeroconf object: queries to handleQuery, everything
    else to handleResponse.

    The constructor registers this object with zeroconf.engine, which
    calls handle_read() whenever the socket is readable."""

    def __init__(self, zeroconf):
        self.zeroconf = zeroconf
        engine = zeroconf.engine
        engine.addReader(self, zeroconf.socket)

    def handle_read(self):
        packet, source = self.zeroconf.socket.recvfrom(_MAX_MSG_ABSOLUTE)
        addr, port = source
        self.data = packet
        msg = DNSIncoming(packet)
        if not msg.isQuery():
            self.zeroconf.handleResponse(msg)
            return
        if port == _MDNS_PORT:
            # Always multicast responses
            self.zeroconf.handleQuery(msg, _MDNS_ADDR, _MDNS_PORT)
        elif port == _DNS_PORT:
            # Not a multicast query: answer via unicast to the sender,
            # and also via multicast
            self.zeroconf.handleQuery(msg, addr, port)
            self.zeroconf.handleQuery(msg, _MDNS_ADDR, _MDNS_PORT)
921 | |
922 | |
class Reaper(threading.Thread):
    """Background thread that, every 10 seconds (or when woken via
    zeroconf.wait), drops expired records from the zeroconf cache after
    reporting them through zeroconf.updateRecord. Starts itself on
    construction."""

    def __init__(self, zeroconf):
        threading.Thread.__init__(self)
        self.zeroconf = zeroconf
        self.start()

    def run(self):
        intervalMillis = 10 * 1000
        while True:
            self.zeroconf.wait(intervalMillis)
            if globals()['_GLOBAL_DONE']:
                return
            now = currentTimeMillis()
            for record in self.zeroconf.cache.entries():
                if not record.isExpired(now):
                    continue
                self.zeroconf.updateRecord(now, record)
                self.zeroconf.cache.remove(record)
942 | |
943 | |
class ServiceBrowser(threading.Thread):
    """Used to browse for a service of a specific type.

    The listener object will have its addService() and
    removeService() methods called when this browser
    discovers changes in the services availability.

    The browser runs as its own thread (started on construction).  It
    periodically multicasts a PTR query for its type, and runs queued
    listener callbacks on that thread."""

    def __init__(self, zeroconf, type, listener):
        """Creates a browser for a specific type"""
        threading.Thread.__init__(self)
        self.zeroconf = zeroconf
        self.type = type
        self.listener = listener
        # Known PTR records for this type, keyed by lower-cased alias.
        self.services = {}
        # Time (ms) at which the next query should be sent.
        self.nextTime = currentTimeMillis()
        # Current query interval (ms); doubled after every query, capped at 20 s.
        self.delay = _BROWSER_TIME
        # Pending listener callbacks, each taking the Zeroconf instance;
        # consumed FIFO by run().
        self.list = []

        self.done = 0

        self.zeroconf.addListener(self, DNSQuestion(self.type, _TYPE_PTR, _CLASS_IN))
        self.start()

    def updateRecord(self, zeroconf, now, record):
        """Callback invoked by Zeroconf when new information arrives.

        Only PTR records for this browser's type are considered.  Such a
        record updates this browser's own service table:

        * an unknown, live record is stored and an addService callback
          is queued;
        * a known, live record refreshes the stored record's TTL;
        * a known, expired record is dropped and a removeService
          callback is queued.

        In every case except a removal, the next query time may also be
        moved earlier so that the record is re-queried before it
        expires."""
        if record.type != _TYPE_PTR or record.name != self.type:
            return

        expired = record.isExpired(now)
        key = record.alias.lower()
        oldrecord = self.services.get(key)

        if oldrecord is None:
            if not expired:
                self.services[key] = record
                callback = lambda x: self.listener.addService(x, self.type, record.alias)
                self.list.append(callback)
        elif expired:
            del self.services[key]
            callback = lambda x: self.listener.removeService(x, self.type, record.alias)
            self.list.append(callback)
            return
        else:
            oldrecord.resetTTL(record)

        # NOTE(review): 75 is presumably a percentage of the record's TTL
        # (see DNSRecord.getExpirationTime) -- confirm.
        expires = record.getExpirationTime(75)
        if expires < self.nextTime:
            self.nextTime = expires

    def cancel(self):
        """Ask the browser thread to stop and wake it up."""
        self.done = 1
        self.zeroconf.notifyAll()

    def run(self):
        """Send periodic PTR queries and dispatch queued callbacks until
        cancelled or until the module is shut down."""
        while True:
            event = None
            now = currentTimeMillis()
            if not self.list and self.nextTime > now:
                self.zeroconf.wait(self.nextTime - now)
            if globals()['_GLOBAL_DONE'] or self.done:
                return
            now = currentTimeMillis()

            if self.nextTime <= now:
                out = DNSOutgoing(_FLAGS_QR_QUERY)
                out.addQuestion(DNSQuestion(self.type, _TYPE_PTR, _CLASS_IN))
                # Include still-valid known answers with the query.
                for record in self.services.values():
                    if not record.isExpired(now):
                        out.addAnswerAtTime(record, now)
                self.zeroconf.send(out)
                self.nextTime = now + self.delay
                self.delay = min(20 * 1000, self.delay * 2)

            if self.list:
                event = self.list.pop(0)

            if event is not None:
                event(self.zeroconf)
1021 | |
1022 | |
class ServiceInfo(object):
    """Service information"""

    def __init__(self, type, name, address=None, port=None, weight=0, priority=0, properties=None, server=None):
        """Create a service description.

        type: fully qualified service type name
        name: fully qualified service name (must end with type)
        address: IPv4 address as a packed 4-byte string in network byte
                 order (e.g. the result of socket.inet_aton)
        port: port that the service runs on
        weight: weight of the service
        priority: priority of the service
        properties: dictionary of properties, or a string holding the
                    bytes for the text field (parsed into properties)
        server: fully qualified name for service host (defaults to name)

        Raises BadTypeInNameException if name does not end with type."""

        if not name.endswith(type):
            raise BadTypeInNameException
        self.type = type
        self.name = name
        self.address = address
        self.port = port
        self.weight = weight
        self.priority = priority
        if server:
            self.server = server
        else:
            self.server = name
        self.setProperties(properties)

    def setProperties(self, properties):
        """Sets properties and text of this info.

        A dictionary is stored as-is and encoded into the TXT wire format:
        each 'key=value' item is prefixed with a one-byte length.  Value
        encoding:

        * None, or an unsupported type: empty;
        * str: UTF-8 encoded;
        * int/bool: 'true' if non-zero, else 'false'.

        A string is treated as already-encoded text and parsed with
        setText().  Anything else (e.g. None) is stored as the text, and
        the properties are cleared."""
        if isinstance(properties, dict):
            self.properties = properties
            parts = []
            for key in properties:
                value = properties[key]
                if value is None:
                    suffix = ''.encode('utf-8')
                elif isinstance(value, str):
                    suffix = value.encode('utf-8')
                elif isinstance(value, int):
                    if value:
                        suffix = 'true'
                    else:
                        suffix = 'false'
                else:
                    suffix = ''.encode('utf-8')
                item = '='.join((key, suffix))
                parts.append(struct.pack('!c', chr(len(item))))
                parts.append(item)
            self.text = ''.join(parts)
        elif isinstance(properties, str):
            self.setText(properties)
        else:
            self.text = properties
            # Previously left unset, which made getProperties() raise.
            self.properties = None

    def setText(self, text):
        """Sets properties and text given a text field.

        The text is a sequence of length-prefixed 'key=value' strings.
        Parsed values are:

        * 'true' becomes 1;
        * 'false', an empty value, or a missing '=' becomes 0;
        * anything else is kept as the string.

        The first occurrence of a key wins.  On a parse error, the
        traceback is printed and the properties become None."""
        self.text = text
        try:
            result = {}
            end = len(text)
            index = 0
            strs = []
            while index < end:
                length = ord(text[index])
                index += 1
                strs.append(text[index:index+length])
                index += length

            for s in strs:
                eindex = s.find('=')
                if eindex == -1:
                    # No equals sign at all
                    key = s
                    value = 0
                else:
                    key = s[:eindex]
                    value = s[eindex+1:]
                    if value == 'true':
                        value = 1
                    elif value == 'false' or not value:
                        value = 0

                # Only update non-existent properties
                if key and key not in result:
                    result[key] = value

            self.properties = result
        except Exception:
            traceback.print_exc()
            self.properties = None

    def getType(self):
        """Type accessor"""
        return self.type

    def getName(self):
        """Name accessor: the service name with the '.<type>' suffix removed."""
        if self.type is not None and self.name.endswith("." + self.type):
            return self.name[:len(self.name) - len(self.type) - 1]
        return self.name

    def getAddress(self):
        """Address accessor"""
        return self.address

    def getPort(self):
        """Port accessor"""
        return self.port

    def getPriority(self):
        """Priority accessor"""
        return self.priority

    def getWeight(self):
        """Weight accessor"""
        return self.weight

    def getProperties(self):
        """Properties accessor"""
        return self.properties

    def getText(self):
        """Text accessor"""
        return self.text

    def getServer(self):
        """Server accessor"""
        return self.server

    def updateRecord(self, zeroconf, now, record):
        """Updates service information from a DNS record.

        Live A, SRV and TXT records matching this service fill in the
        address, the server/port/weight/priority, and the text
        respectively.  After an SRV update, a cached A record for the new
        server is applied as well.  None or expired records are ignored."""
        if record is not None and not record.isExpired(now):
            if record.type == _TYPE_A:
                if record.name == self.server:
                    self.address = record.address
            elif record.type == _TYPE_SRV:
                if record.name == self.name:
                    self.server = record.server
                    self.port = record.port
                    self.weight = record.weight
                    self.priority = record.priority
                    self.updateRecord(zeroconf, now, zeroconf.cache.getByDetails(self.server, _TYPE_A, _CLASS_IN))
            elif record.type == _TYPE_TXT:
                if record.name == self.name:
                    self.setText(record.text)

    def request(self, zeroconf, timeout):
        """Returns true if the service could be discovered on the
        network, and updates this object with details discovered.

        timeout is in milliseconds.  Queries for SRV, TXT and (once the
        server is known) A records are sent with exponential back-off
        until server, address and text are all known or the timeout
        expires.
        """
        now = currentTimeMillis()
        delay = _LISTENER_TIME
        nextTime = now + delay
        last = now + timeout
        try:
            zeroconf.addListener(self, DNSQuestion(self.name, _TYPE_ANY, _CLASS_IN))
            while self.server is None or self.address is None or self.text is None:
                if last <= now:
                    return 0
                if nextTime <= now:
                    out = DNSOutgoing(_FLAGS_QR_QUERY)
                    out.addQuestion(DNSQuestion(self.name, _TYPE_SRV, _CLASS_IN))
                    out.addAnswerAtTime(zeroconf.cache.getByDetails(self.name, _TYPE_SRV, _CLASS_IN), now)
                    out.addQuestion(DNSQuestion(self.name, _TYPE_TXT, _CLASS_IN))
                    out.addAnswerAtTime(zeroconf.cache.getByDetails(self.name, _TYPE_TXT, _CLASS_IN), now)
                    if self.server is not None:
                        out.addQuestion(DNSQuestion(self.server, _TYPE_A, _CLASS_IN))
                        out.addAnswerAtTime(zeroconf.cache.getByDetails(self.server, _TYPE_A, _CLASS_IN), now)
                    zeroconf.send(out)
                    nextTime = now + delay
                    delay = delay * 2

                zeroconf.wait(min(nextTime, last) - now)
                now = currentTimeMillis()
            return 1
        finally:
            zeroconf.removeListener(self)

    def __eq__(self, other):
        """Tests equality of service name"""
        if isinstance(other, ServiceInfo):
            return other.name == self.name
        return 0

    def __ne__(self, other):
        """Non-equality test"""
        return not self.__eq__(other)

    def __repr__(self):
        """String representation (safe before the address is resolved)"""
        if self.address is None:
            address = "None"
        else:
            address = socket.inet_ntoa(self.address)
        result = "service[%s,%s:%s," % (self.name, address, self.port)
        if self.text is None:
            result += "None"
        else:
            if len(self.text) < 20:
                result += self.text
            else:
                result += self.text[:17] + "..."
        result += "]"
        return result
1229 | |
1230 | |
class Zeroconf(object):
    """Implementation of Zeroconf Multicast DNS Service Discovery

    Supports registration, unregistration, queries and browsing.
    """
    def __init__(self, bindaddress=None):
        """Creates an instance of the Zeroconf class, establishing
        multicast communications, listening and reaping threads.

        bindaddress: interface address recorded in self.intf.  When
        omitted, the address the local host name resolves to is used.
        The socket itself is bound to ('', _MDNS_PORT)."""
        globals()['_GLOBAL_DONE'] = 0
        if bindaddress is None:
            self.intf = socket.gethostbyname(socket.gethostname())
        else:
            self.intf = bindaddress
        self.group = ('', _MDNS_PORT)
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        except (AttributeError, socket.error):
            # SO_REUSEADDR should be equivalent to SO_REUSEPORT for
            # multicast UDP sockets (p 731, "TCP/IP Illustrated,
            # Volume 2"), but some BSD-derived systems require
            # SO_REUSEPORT to be specified explicitly. Also, not all
            # versions of Python have SO_REUSEPORT available (hence
            # AttributeError). So if you're on a BSD-based system, and
            # haven't upgraded to Python 2.3 yet, you may find this
            # library doesn't work as expected.
            pass
        self.socket.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_TTL, 255)
        self.socket.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_LOOP, 1)
        try:
            self.socket.bind(self.group)
        except socket.error:
            # Some versions of linux raise an exception even though
            # the SO_REUSE* options have been set, so ignore it
            pass
        #self.socket.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_IF, socket.inet_aton(self.intf) + socket.inet_aton('0.0.0.0'))
        self.socket.setsockopt(socket.SOL_IP, socket.IP_ADD_MEMBERSHIP, socket.inet_aton(_MDNS_ADDR) + socket.inet_aton('0.0.0.0'))

        self.listeners = []
        self.browsers = []
        # Registered ServiceInfo objects, keyed by lower-cased name.
        self.services = {}
        # Registered service type -> number of services of that type.
        self.servicetypes = {}

        self.cache = DNSCache()

        self.condition = threading.Condition()

        self.engine = Engine(self)
        self.listener = Listener(self)
        self.reaper = Reaper(self)

    def isLoopback(self):
        """Return true if the interface address starts with 127.0.0.1."""
        return self.intf.startswith("127.0.0.1")

    def isLinklocal(self):
        """Return true if the interface address is in 169.254.0.0/16."""
        return self.intf.startswith("169.254.")

    def wait(self, timeout):
        """Calling thread waits for a given number of milliseconds or
        until notified."""
        self.condition.acquire()
        try:
            # Float division: on Python 2, timeout/1000 truncated integer
            # timeouts (e.g. 999 ms became 0 s).
            self.condition.wait(timeout / 1000.0)
        finally:
            self.condition.release()

    def notifyAll(self):
        """Notifies all waiting threads"""
        self.condition.acquire()
        try:
            self.condition.notifyAll()
        finally:
            self.condition.release()

    def getServiceInfo(self, type, name, timeout=3000):
        """Returns network's service information for a particular
        name and type, or None if no service matches by the timeout,
        which defaults to 3 seconds."""
        info = ServiceInfo(type, name)
        if info.request(self, timeout):
            return info
        return None

    def addServiceListener(self, type, listener):
        """Adds a listener for a particular service type. This object
        will then have its updateRecord method called when information
        arrives for that type."""
        self.removeServiceListener(listener)
        self.browsers.append(ServiceBrowser(self, type, listener))

    def removeServiceListener(self, listener):
        """Removes a listener from the set that is currently listening.

        Every browser driving this listener is cancelled and dropped from
        self.browsers."""
        remaining = []
        for browser in self.browsers:
            if browser.listener == listener:
                browser.cancel()
            else:
                remaining.append(browser)
        self.browsers = remaining

    def _addServiceAnswers(self, out, info, ttl):
        """Add the PTR, SRV and TXT records for *info* to *out* (plus an
        A record when the address is known), all with the given TTL.
        Unregistration passes a TTL of 0."""
        out.addAnswerAtTime(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, ttl, info.name), 0)
        out.addAnswerAtTime(DNSService(info.name, _TYPE_SRV, _CLASS_IN, ttl, info.priority, info.weight, info.port, info.server), 0)
        out.addAnswerAtTime(DNSText(info.name, _TYPE_TXT, _CLASS_IN, ttl, info.text), 0)
        if info.address:
            out.addAnswerAtTime(DNSAddress(info.server, _TYPE_A, _CLASS_IN, ttl, info.address), 0)

    def _sendRepeatedly(self, makePacket, interval, count=3):
        """Send a packet built by makePacket() count times, interval
        milliseconds apart.  The first send is immediate, and the packet
        is rebuilt before each send."""
        now = currentTimeMillis()
        nextTime = now
        i = 0
        while i < count:
            if now < nextTime:
                self.wait(nextTime - now)
                now = currentTimeMillis()
                continue
            self.send(makePacket())
            i += 1
            nextTime += interval

    def registerService(self, info, ttl=_DNS_TTL):
        """Registers service information to the network with a default TTL
        of _DNS_TTL. Zeroconf will then respond to requests for
        information for that service. The name of the service may be
        changed if needed to make it unique on the network."""
        self.checkService(info)
        self.services[info.name.lower()] = info
        self.servicetypes[info.type] = self.servicetypes.get(info.type, 0) + 1

        def announcement():
            out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
            self._addServiceAnswers(out, info, ttl)
            return out
        self._sendRepeatedly(announcement, _REGISTER_TIME)

    def unregisterService(self, info):
        """Unregister a service.

        Goodbye packets (TTL 0) are sent even if the service was not
        registered with this instance."""
        try:
            del self.services[info.name.lower()]
            if self.servicetypes[info.type] > 1:
                self.servicetypes[info.type] -= 1
            else:
                del self.servicetypes[info.type]
        except KeyError:
            pass

        def goodbye():
            out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
            # SRV target is info.server, matching registerService (it was
            # previously info.name here).
            self._addServiceAnswers(out, info, 0)
            return out
        self._sendRepeatedly(goodbye, _UNREGISTER_TIME)

    def unregisterAllServices(self):
        """Unregister all registered services (sends goodbye packets;
        the local service tables are left as they are)."""
        if not self.services:
            return

        def goodbyes():
            out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
            for info in self.services.values():
                self._addServiceAnswers(out, info, 0)
            return out
        self._sendRepeatedly(goodbyes, _UNREGISTER_TIME)

    def checkService(self, info):
        """Checks the network for a unique service name, modifying the
        ServiceInfo passed in if it is not unique.

        Raises NonUniqueNameException when a live PTR record for the type
        already points at info.name."""
        now = currentTimeMillis()
        nextTime = now
        i = 0
        while i < 3:
            for record in self.cache.entriesWithName(info.type):
                if record.type == _TYPE_PTR and not record.isExpired(now) and record.alias == info.name:
                    if info.name.find('.') < 0:
                        # NOTE(review): only reachable for dot-less names,
                        # which ServiceInfo's endswith(type) check makes
                        # rare.  port is an int, hence str().
                        info.name = info.name + ".[" + info.address + ":" + str(info.port) + "]." + info.type
                        self.checkService(info)
                        return
                    raise NonUniqueNameException
            if now < nextTime:
                self.wait(nextTime - now)
                now = currentTimeMillis()
                continue
            out = DNSOutgoing(_FLAGS_QR_QUERY | _FLAGS_AA)
            self.debug = out
            out.addQuestion(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN))
            out.addAuthorativeAnswer(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, _DNS_TTL, info.name))
            self.send(out)
            i += 1
            nextTime += _CHECK_TIME

    def addListener(self, listener, question):
        """Adds a listener for a given question. The listener will have
        its updateRecord method called when information is available to
        answer the question."""
        now = currentTimeMillis()
        self.listeners.append(listener)
        if question is not None:
            # Replay live cached answers immediately.
            for record in self.cache.entriesWithName(question.name):
                if question.answeredBy(record) and not record.isExpired(now):
                    listener.updateRecord(self, now, record)
        self.notifyAll()

    def removeListener(self, listener):
        """Removes a listener (no-op if it is not registered)."""
        try:
            self.listeners.remove(listener)
        except ValueError:
            return
        self.notifyAll()

    def updateRecord(self, now, rec):
        """Used to notify listeners of new information that has updated
        a record."""
        # Iterate a snapshot: other threads (e.g. ServiceInfo.request)
        # add and remove listeners concurrently.
        for listener in self.listeners[:]:
            listener.updateRecord(self, now, rec)
        self.notifyAll()

    def handleResponse(self, msg):
        """Deal with incoming response packets. All answers
        are held in the cache, and listeners are notified."""
        now = currentTimeMillis()
        for record in msg.answers:
            expired = record.isExpired(now)
            if record in self.cache.entries():
                if expired:
                    self.cache.remove(record)
                else:
                    entry = self.cache.get(record)
                    if entry is not None:
                        entry.resetTTL(record)
                        record = entry
            else:
                self.cache.add(record)

            self.updateRecord(now, record)

    def handleQuery(self, msg, addr, port):
        """Deal with incoming query packets. Provides a response if
        possible."""
        out = None

        # Support unicast client responses
        if port != _MDNS_PORT:
            out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, 0)
            for question in msg.questions:
                out.addQuestion(question)

        for question in msg.questions:
            if question.type == _TYPE_PTR:
                if question.name == "_services._dns-sd._udp.local.":
                    for stype in list(self.servicetypes):
                        if out is None:
                            out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
                        out.addAnswer(msg, DNSPointer("_services._dns-sd._udp.local.", _TYPE_PTR, _CLASS_IN, _DNS_TTL, stype))
                for service in list(self.services.values()):
                    if question.name == service.type:
                        if out is None:
                            out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
                        out.addAnswer(msg, DNSPointer(service.type, _TYPE_PTR, _CLASS_IN, _DNS_TTL, service.name))
            else:
                try:
                    if out is None:
                        out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)

                    # Answer A record queries for any service addresses we know
                    if question.type == _TYPE_A or question.type == _TYPE_ANY:
                        for service in list(self.services.values()):
                            if service.server == question.name.lower():
                                out.addAnswer(msg, DNSAddress(question.name, _TYPE_A, _CLASS_IN | _CLASS_UNIQUE, _DNS_TTL, service.address))

                    service = self.services.get(question.name.lower(), None)
                    if not service:
                        continue

                    if question.type == _TYPE_SRV or question.type == _TYPE_ANY:
                        out.addAnswer(msg, DNSService(question.name, _TYPE_SRV, _CLASS_IN | _CLASS_UNIQUE, _DNS_TTL, service.priority, service.weight, service.port, service.server))
                    if question.type == _TYPE_TXT or question.type == _TYPE_ANY:
                        out.addAnswer(msg, DNSText(question.name, _TYPE_TXT, _CLASS_IN | _CLASS_UNIQUE, _DNS_TTL, service.text))
                    if question.type == _TYPE_SRV:
                        out.addAdditionalAnswer(DNSAddress(service.server, _TYPE_A, _CLASS_IN | _CLASS_UNIQUE, _DNS_TTL, service.address))
                except Exception:
                    traceback.print_exc()

        if out is not None and out.answers:
            out.id = msg.id
            self.send(out, addr, port)

    def send(self, out, addr=_MDNS_ADDR, port=_MDNS_PORT):
        """Sends an outgoing packet (to the mDNS multicast group by default).

        Send failures are deliberately ignored: they may just be a
        temporary loss of network connection."""
        try:
            self.socket.sendto(out.packet(), 0, (addr, port))
        except Exception:
            pass

    def close(self):
        """Ends the background threads, and prevent this instance from
        servicing further queries."""
        if globals()['_GLOBAL_DONE'] == 0:
            globals()['_GLOBAL_DONE'] = 1
            self.notifyAll()
            self.engine.notify()
            self.unregisterAllServices()
            self.socket.setsockopt(socket.SOL_IP, socket.IP_DROP_MEMBERSHIP, socket.inet_aton(_MDNS_ADDR) + socket.inet_aton('0.0.0.0'))
            self.socket.close()
1549 | |
# Test a few module features, including service registration, service
# query (for Zoe), and service unregistration.

if __name__ == '__main__':
    # Single-string print(...) calls produce the same output on Python 2
    # and also parse on Python 3.
    print("Multicast DNS Service Discovery for Python, version %s" % (__version__,))
    r = Zeroconf()
    try:
        print("1. Testing registration of a service...")
        desc = {'version': '0.10', 'a': 'test value', 'b': 'another value'}
        info = ServiceInfo("_http._tcp.local.", "My Service Name._http._tcp.local.", socket.inet_aton("127.0.0.1"), 1234, 0, 0, desc)
        print(" Registering service...")
        r.registerService(info)
        print(" Registration done.")
        print("2. Testing query of service information...")
        print(" Getting ZOE service: %s" % (r.getServiceInfo("_http._tcp.local.", "ZOE._http._tcp.local."),))
        print(" Query done.")
        print("3. Testing query of own service...")
        print(" Getting self: %s" % (r.getServiceInfo("_http._tcp.local.", "My Service Name._http._tcp.local."),))
        print(" Query done.")
        print("4. Testing unregister of service information...")
        r.unregisterService(info)
        print(" Unregister done.")
    finally:
        # Always stop the background threads, or a failure above would
        # leave the process hanging on them.
        r.close()