Download Install Tutorial Docs FAQ Tools Wiki License Team IRC Planet Involvement Shop Book

root/branches/cp3-wsgi-remix/test/webtest.py

Revision 1186 (checked in by fumanchu, 2 years ago)

webtest.WebCase.assertHeader now returns the header value if found.

  • Property svn:eol-style set to native
Line 
1 """Extensions to unittest for web frameworks.
2
3 Use the WebCase.getPage method to request a page from your HTTP server.
4
5 Framework Integration
6 =====================
7
8 If you have control over your server process, you can handle errors
9 in the server-side of the HTTP conversation a bit better. You must run
10 both the client (your WebCase tests) and the server in the same process
11 (but in separate threads, obviously).
12
13 When an error occurs in the framework, call server_error. It will print
14 the traceback to stdout, and keep any assertions you have from running
15 (the assumption is that, if the server errors, the page output won't be
16 of further significance to your tests).
17 """
18
19 import os, sys, time, re
20 import types
21 import pprint
22 import socket
23 import httplib
24 import traceback
25
26 from unittest import *
27 from unittest import _TextTestResult
28
29
class TerseTestResult(_TextTestResult):
    """Text test result whose error report has no superfluous blank line.

    printErrors is overridden so that a run with no errors and no failures
    writes nothing at all.
    """

    def printErrors(self):
        # Nothing to report: stay silent (avoids an unnecessary empty line).
        if not (self.errors or self.failures):
            return
        if self.dots or self.showAll:
            self.stream.writeln()
        for flavour, problems in (('ERROR', self.errors),
                                  ('FAIL', self.failures)):
            self.printErrorList(flavour, problems)
39
40
class TerseTestRunner(TextTestRunner):
    """A test runner class that displays results in textual form.

    run() is overridden to remove unnecessary empty lines and separators:
    it writes only the error/failure report produced by TerseTestResult
    and, if the run was not successful, a one-line "FAILED (...)" summary.
    """

    def _makeResult(self):
        return TerseTestResult(self.stream, self.descriptions, self.verbosity)

    def run(self, test):
        "Run the given test case or test suite."
        # No timing is collected: nothing here ever reported it.
        result = self._makeResult()
        test(result)
        result.printErrors()
        if not result.wasSuccessful():
            # e.g. "FAILED (failures=1, errors=2)"; zero counts are omitted.
            counts = []
            if result.failures:
                counts.append("failures=%d" % len(result.failures))
            if result.errors:
                counts.append("errors=%d" % len(result.errors))
            self.stream.writeln("FAILED (%s)" % ", ".join(counts))
        return result
65
66
class ReloadingTestLoader(TestLoader):
    """TestLoader that reload()s a named module if it is already imported.

    This lets a long-running process pick up edits to test modules when
    the same tests are loaded again.
    """

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all tests cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.
        When module is None, the longest dotted prefix of name that can be
        imported is used as the module (reloaded if already present in
        sys.modules, imported otherwise); the remaining parts of name are
        looked up as attributes of that module.

        Raises ImportError if no prefix of name is importable, and
        ValueError if the resolved object cannot be turned into a test.
        """
        parts = name.split('.')
        if module is None:
            # Find the longest dotted prefix of name that is a module.
            parts_copy = parts[:]
            while parts_copy:
                target = ".".join(parts_copy)
                if target in sys.modules:
                    reload(sys.modules[target])
                    break
                try:
                    __import__(target)
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__("a.b") returns the top-level package "a", but the
            # module we found is "a.b" itself; fetch it from sys.modules so
            # both branches agree, and keep only the parts after target.
            module = sys.modules[target]
            parts = parts[len(parts_copy):]
        obj = module
        for part in parts:
            obj = getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif (isinstance(obj, (type, types.ClassType)) and
              issubclass(obj, TestCase)):
            return self.loadTestsFromTestCase(obj)
        elif type(obj) == types.UnboundMethodType:
            return obj.im_class(obj.__name__)
        elif callable(obj):
            test = obj()
            if not isinstance(test, TestCase) and \
               not isinstance(test, TestSuite):
                raise ValueError("calling %s returned %s, "
                                 "not a test" % (obj,test))
            return test
        else:
            raise ValueError("don't know how to make test from: %s" % obj)
108             return obj.im_class(obj.__name__)
109         elif callable(obj):
110             test = obj()
111             if not isinstance(test, TestCase) and \
112                not isinstance(test, TestSuite):
113                 raise ValueError("calling %s returned %s, "
114                                  "not a test" % (obj,test))
115             return test
116         else:
117             raise ValueError("don't know how to make test from: %s" % obj)
118
119
120 try:
121     # On Windows, msvcrt.getch reads a single char without output.
122     import msvcrt
123     def getchar():
124         return msvcrt.getch()
125 except ImportError:
126     # Unix getchr
127     import tty, termios
128     def getchar():
129         fd = sys.stdin.fileno()
130         old_settings = termios.tcgetattr(fd)
131         try:
132             tty.setraw(sys.stdin.fileno())
133             ch = sys.stdin.read(1)
134         finally:
135             termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
136         return ch
137
138
139 class WebCase(TestCase):
140     HOST = "127.0.0.1"
141     PORT = 8000
142     HTTP_CONN=httplib.HTTPConnection
143    
144     def getPage(self, url, headers=None, method="GET", body=None, protocol="HTTP/1.1"):
145         """Open the url with debugging support. Return status, headers, body."""
146         ServerError.on = False
147        
148         self.url = url
149         result = openURL(url, headers, method, body, self.HOST, self.PORT,
150                          self.HTTP_CONN, protocol)
151         self.status, self.headers, self.body = result
152        
153         # Build a list of request cookies from the previous response cookies.
154         self.cookies = [('Cookie', v) for k, v in self.headers
155                         if k.lower() == 'set-cookie']
156        
157         if ServerError.on:
158             raise ServerError()
159         return result
160    
161     interactive = True
162     console_height = 30
163    
164     def _handlewebError(self, msg):
165         print
166         print "    ERROR:", msg
167        
168         if not self.interactive:
169             raise self.failureException(msg)
170        
171         p = "    Show: [B]ody [H]eaders [S]tatus [U]RL; [I]gnore, [R]aise, or sys.e[X]it >> "
172         print p,
173         while True:
174             i = getchar().upper()
175             if i not in "BHSUIRX":
176                 continue
177             print i.upper()  # Also prints new line
178             if i == "B":
179                 for x, line in enumerate(self.body.splitlines()):
180                     if (x + 1) % self.console_height == 0:
181                         # The \r and comma should make the next line overwrite
182                         print "<-- More -->\r",
183                         m = getchar().lower()
184                         # Erase our "More" prompt
185                         print "            \r",
186                         if m == "q":
187                             break
188                     print line
189             elif i == "H":
190                 pprint.pprint(self.headers)
191             elif i == "S":
192                 print self.status
193             elif i == "U":
194                 print self.url
195             elif i == "I":
196                 # return without raising the normal exception
197                 return
198             elif i == "R":
199                 raise self.failureException(msg)
200             elif i == "X":
201                 self.exit()
202             print p,
203    
204     def exit(self):
205         sys.exit()
206    
207     def __call__(self, result=None):
208         if result is None:
209             result = self.defaultTestResult()
210         result.startTest(self)
211         testMethod = getattr(self, self._TestCase__testMethodName)
212         try:
213             try:
214                 self.setUp()
215             except (KeyboardInterrupt, SystemExit):
216                 raise
217             except:
218                 result.addError(self, self._TestCase__exc_info())
219                 return
220            
221             ok = 0
222             try:
223                 testMethod()
224                 ok = 1
225             except self.failureException:
226                 result.addFailure(self, self._TestCase__exc_info())
227             except (KeyboardInterrupt, SystemExit):
228                 raise
229             except:
230                 result.addError(self, self._TestCase__exc_info())
231            
232             try:
233                 self.tearDown()
234             except (KeyboardInterrupt, SystemExit):
235                 raise
236             except:
237                 result.addError(self, self._TestCase__exc_info())
238                 ok = 0
239             if ok:
240                 result.addSuccess(self)
241         finally:
242             result.stopTest(self)
243    
244     def assertStatus(self, status, msg=None):
245         """Fail if self.status != status."""
246         if isinstance(status, basestring):
247             if not self.status == status:
248                 if msg is None:
249                     msg = 'Status (%s) != %s' % (`self.status`, `status`)
250                 self._handlewebError(msg)
251         elif isinstance(status, int):
252             code = int(self.status[:3])
253             if code != status:
254                 if msg is None:
255                     msg = 'Status (%s) != %s' % (`self.status`, `status`)
256                 self._handlewebError(msg)
257         else:
258             # status is a tuple or list.
259             match = False
260             for s in status:
261                 if isinstance(s, basestring):
262                     if self.status == s:
263                         match = True
264                         break
265                 elif int(self.status[:3]) == s:
266                     match = True
267                     break
268             if not match:
269                 if msg is None:
270                     msg = 'Status (%s) not in %s' % (`self.status`, `status`)
271                 self._handlewebError(msg)
272    
273     def assertHeader(self, key, value=None, msg=None):
274         """Fail if (key, [value]) not in self.headers."""
275         lowkey = key.lower()
276         for k, v in self.headers:
277             if k.lower() == lowkey:
278                 if value is None or str(value) == v:
279                     return v
280        
281         if msg is None:
282             if value is None:
283                 msg = '%s not in headers' % `key`
284             else:
285                 msg = '%s:%s not in headers' % (`key`, `value`)
286         self._handlewebError(msg)
287    
288     def assertNoHeader(self, key, msg=None):
289         """Fail if key in self.headers."""
290         lowkey = key.lower()
291         matches = [k for k, v in self.headers if k.lower() == lowkey]
292         if matches:
293             if msg is None:
294                 msg = '%s in headers' % `key`
295             self._handlewebError(msg)
296    
297     def assertBody(self, value, msg=None):
298         """Fail if value != self.body."""
299         if value != self.body:
300             if msg is None:
301                 msg = 'expected body:\n%s\n\nactual body:\n%s' % (`value`, `self.body`)
302             self._handlewebError(msg)
303    
304     def assertInBody(self, value, msg=None):
305         """Fail if value not in self.body."""
306         if value not in self.body:
307             if msg is None:
308                 msg = '%s not in body' % `value`
309             self._handlewebError(msg)
310    
311     def assertNotInBody(self, value, msg=None):
312         """Fail if value in self.body."""
313         if value in self.body:
314             if msg is None:
315                 msg = '%s found in body' % `value`
316             self._handlewebError(msg)
317    
318     def assertMatchesBody(self, pattern, msg=None, flags=0):
319         """Fail if value (a regex pattern) is not in self.body."""
320         if re.search(pattern, self.body, flags) is None:
321             if msg is None:
322                 msg = 'No match for %s in body' % `pattern`
323             self._handlewebError(msg)
324
325
326 methods_with_bodies = ("POST", "PUT")
327
328 def cleanHeaders(headers, method, body, host, port):
329     """Return request headers, with required headers added (if missing)."""
330     if headers is None:
331         headers = []
332    
333     # Add the required Host request header if not present.
334     # [This specifies the host:port of the server, not the client.]
335     found = False
336     for k, v in headers:
337         if k.lower() == 'host':
338             found = True
339             break
340     if not found:
341         headers.append(("Host", "%s:%s" % (host, port)))
342    
343     if method in methods_with_bodies:
344         # Stick in default type and length headers if not present
345         found = False
346         for k, v in headers:
347             if k.lower() == 'content-type':
348                 found = True
349                 break
350         if not found:
351             headers.append(("Content-Type", "application/x-www-form-urlencoded"))
352             headers.append(("Content-Length", str(len(body or ""))))
353    
354     return headers
355
356
357 def openURL(url, headers=None, method="GET", body=None,
358             host="127.0.0.1", port=8000, http_conn=httplib.HTTPConnection,
359             protocol="HTTP/1.1"):
360     """Open the given HTTP resource and return status, headers, and body."""
361    
362     headers = cleanHeaders(headers, method, body, host, port)
363    
364     # Trying 10 times is simply in case of socket errors.
365     # Normal case--it should run once.
366     trial = 0
367     while trial < 10:
368         try:
369             conn = http_conn(host, port)
370             conn._http_vsn_str = protocol
371             conn._http_vsn = int("".join([x for x in protocol if x.isdigit()]))
372            
373             # skip_accept_encoding argument added in python version 2.4
374             if sys.version_info < (2, 4):
375                 conn.putrequest(method.upper(), url, skip_host=True)
376             else:
377                 conn.putrequest(method.upper(), url, skip_host=True,
378                                 skip_accept_encoding=True)
379            
380             for key, value in headers:
381                 conn.putheader(key, value)
382             conn.endheaders()
383            
384             if body is not None:
385                 conn.send(body)
386            
387             # Handle response
388             response = conn.getresponse()
389            
390             status = "%s %s" % (response.status, response.reason)
391            
392             outheaders = []
393             key, value = None, None
394             for line in response.msg.headers:
395                 if line:
396                     if line[0] in " \t":
397                         value += line.strip()
398                     else:
399                         if key and value:
400                             outheaders.append((key, value))
401                         key, value = line.split(":", 1)
402                         key = key.strip()
403                         value = value.strip()
404             if key and value:
405                 outheaders.append((key, value))
406            
407             outbody = response.read()
408            
409             conn.close()
410             return status, outheaders, outbody
411         except socket.error:
412             trial += 1
413             if trial >= 10:
414                 raise
415             else:
416                 time.sleep(0.5)
417
418
419 # Add any exceptions which your web framework handles
420 # normally (that you don't want server_error to trap).
421 ignored_exceptions = []
422
423 # You'll want set this to True when you can't guarantee
424 # that each response will immediately follow each request;
425 # for example, when handling requests via multiple threads.
426 ignore_all = False
427
428 class ServerError(Exception):
429     on = False
430
431
432 def server_error(exc=None):
433     """Server debug hook. Return True if exception handled, False if ignored.
434     
435     You probably want to wrap this, so you can still handle an error using
436     your framework when it's ignored.
437     """
438     if exc is None:
439         exc = sys.exc_info()
440    
441     if ignore_all or exc[0] in ignored_exceptions:
442         return False
443     else:
444         ServerError.on = True
445         print
446         print "".join(traceback.format_exception(*exc))
447         return True
448
Note: See TracBrowser for help on using the browser.

Hosted by WebFaction

Log in as guest/cpguest to create tickets