Add unit tests for when CAS_FEDERATE is True

Also fix some unicode related bugs
This commit is contained in:
Valentin Samir
2016-07-03 13:51:00 +02:00
parent fcd906ca78
commit 90daf3d2a0
13 changed files with 749 additions and 144 deletions

View File

@ -13,14 +13,33 @@
from cas_server.default_settings import settings
from django.test import Client
from django.template import loader, Context
from django.utils import timezone
import cgi
import six
from threading import Thread
from lxml import etree
from six.moves import BaseHTTPServer
from six.moves.urllib.parse import urlparse, parse_qsl
from datetime import timedelta
from cas_server import models
from cas_server import utils
def return_unicode(string, charset):
if not isinstance(string, six.text_type):
return string.decode(charset)
else:
return string
def return_bytes(string, charset):
if isinstance(string, six.text_type):
return string.encode(charset)
else:
return string
def copy_form(form):
@ -149,10 +168,10 @@ class HttpParamsHandler(BaseHTTPServer.BaseHTTPRequestHandler):
return
@classmethod
def run(cls):
def run(cls, port=0):
"""Run a BaseHTTPServer using this class as handler"""
server_class = BaseHTTPServer.HTTPServer
httpd = server_class(("127.0.0.1", 0), cls)
httpd = server_class(("127.0.0.1", port), cls)
(host, port) = httpd.socket.getsockname()
def lauch():
@ -178,3 +197,143 @@ class Http404Handler(HttpParamsHandler):
def do_POST(self):
"""Called on a POST request on the BaseHTTPServer"""
return self.do_GET()
class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
def test_params(self):
if (
self.server.ticket is not None and
self.params.get("service").encode("ascii") == self.server.service and
self.params.get("ticket").encode("ascii") == self.server.ticket
):
self.server.ticket = None
print("good")
return True
else:
print("bad (%r, %r) != (%r, %r)" % (
self.params.get("service").encode("ascii"),
self.params.get("ticket").encode("ascii"),
self.server.service,
self.server.ticket
))
return False
def send_headers(self, code, content_type):
self.send_response(200)
self.send_header("Content-type", content_type)
self.end_headers()
def do_GET(self):
url = urlparse(self.path)
self.params = dict(parse_qsl(url.query))
if url.path == "/validate":
self.send_headers(200, "text/plain; charset=utf-8")
if self.test_params():
self.wfile.write(b"yes\n" + self.server.username + b"\n")
self.server.ticket = None
else:
self.wfile.write(b"no\n")
elif url.path in {
'/serviceValidate', '/serviceValidate',
'/p3/serviceValidate', '/p3/proxyValidate'
}:
self.send_headers(200, "text/xml; charset=utf-8")
if self.test_params():
t = loader.get_template('cas_server/serviceValidate.xml')
c = Context({
'username': self.server.username,
'attributes': self.server.attributes
})
self.wfile.write(return_bytes(t.render(c), "utf8"))
else:
t = loader.get_template('cas_server/serviceValidateError.xml')
c = Context({
'code': 'BAD_SERVICE_TICKET',
'msg': 'Valids are (%r, %r)' % (self.server.service, self.server.ticket)
})
self.wfile.write(return_bytes(t.render(c), "utf8"))
else:
self.return_404()
def do_POST(self):
url = urlparse(self.path)
self.params = dict(parse_qsl(url.query))
if url.path == "/samlValidate":
self.send_headers(200, "text/xml; charset=utf-8")
length = int(self.headers.get('content-length'))
root = etree.fromstring(self.rfile.read(length))
auth_req = root.getchildren()[1].getchildren()[0]
ticket = auth_req.getchildren()[0].text.encode("ascii")
if (
self.server.ticket is not None and
self.params.get("TARGET").encode("ascii") == self.server.service and
ticket == self.server.ticket
):
self.server.ticket = None
t = loader.get_template('cas_server/samlValidate.xml')
c = Context({
'IssueInstant': timezone.now().isoformat(),
'expireInstant': (timezone.now() + timedelta(seconds=60)).isoformat(),
'Recipient': self.server.service,
'ResponseID': utils.gen_saml_id(),
'username': self.server.username,
'attributes': self.server.attributes,
})
self.wfile.write(return_bytes(t.render(c), "utf8"))
else:
t = loader.get_template('cas_server/samlValidateError.xml')
c = Context({
'IssueInstant': timezone.now().isoformat(),
'ResponseID': utils.gen_saml_id(),
'code': 'BAD_SERVICE_TICKET',
'msg': 'Valids are (%r, %r)' % (self.server.service, self.server.ticket)
})
self.wfile.write(return_bytes(t.render(c), "utf8"))
else:
self.return_404()
def return_404(self):
self.send_response(404)
self.send_header(b"Content-type", "text/plain")
self.end_headers()
self.wfile.write("not found")
def log_message(self, *args):
"""silent any log message"""
return
@classmethod
def run(cls, service, ticket, username, attributes, port=0):
"""Run a BaseHTTPServer using this class as handler"""
server_class = BaseHTTPServer.HTTPServer
httpd = server_class(("127.0.0.1", port), cls)
httpd.service = service
httpd.ticket = ticket
httpd.username = username
httpd.attributes = attributes
(host, port) = httpd.socket.getsockname()
def lauch():
"""routine to lauch in a background thread"""
httpd.handle_request()
httpd.server_close()
httpd_thread = Thread(target=lauch)
httpd_thread.daemon = True
httpd_thread.start()
return (httpd, host, port)
def logout_request(ticket):
return u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
<samlp:SessionIndex>%(ticket)s</samlp:SessionIndex>
</samlp:LogoutRequest>""" % \
{
'id': utils.gen_saml_id(),
'datetime': timezone.now().isoformat(),
'ticket': ticket
}