Add unit tests for when CAS_FEDERATE is True
Also fix some unicode related bugs
This commit is contained in:
@ -13,14 +13,33 @@
|
||||
from cas_server.default_settings import settings
|
||||
|
||||
from django.test import Client
|
||||
from django.template import loader, Context
|
||||
from django.utils import timezone
|
||||
|
||||
import cgi
|
||||
import six
|
||||
from threading import Thread
|
||||
from lxml import etree
|
||||
from six.moves import BaseHTTPServer
|
||||
from six.moves.urllib.parse import urlparse, parse_qsl
|
||||
from datetime import timedelta
|
||||
|
||||
from cas_server import models
|
||||
from cas_server import utils
|
||||
|
||||
|
||||
def return_unicode(string, charset):
|
||||
if not isinstance(string, six.text_type):
|
||||
return string.decode(charset)
|
||||
else:
|
||||
return string
|
||||
|
||||
|
||||
def return_bytes(string, charset):
|
||||
if isinstance(string, six.text_type):
|
||||
return string.encode(charset)
|
||||
else:
|
||||
return string
|
||||
|
||||
|
||||
def copy_form(form):
|
||||
@ -149,10 +168,10 @@ class HttpParamsHandler(BaseHTTPServer.BaseHTTPRequestHandler):
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def run(cls):
|
||||
def run(cls, port=0):
|
||||
"""Run a BaseHTTPServer using this class as handler"""
|
||||
server_class = BaseHTTPServer.HTTPServer
|
||||
httpd = server_class(("127.0.0.1", 0), cls)
|
||||
httpd = server_class(("127.0.0.1", port), cls)
|
||||
(host, port) = httpd.socket.getsockname()
|
||||
|
||||
def lauch():
|
||||
@ -178,3 +197,143 @@ class Http404Handler(HttpParamsHandler):
|
||||
def do_POST(self):
|
||||
"""Called on a POST request on the BaseHTTPServer"""
|
||||
return self.do_GET()
|
||||
|
||||
|
||||
class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
|
||||
|
||||
def test_params(self):
|
||||
if (
|
||||
self.server.ticket is not None and
|
||||
self.params.get("service").encode("ascii") == self.server.service and
|
||||
self.params.get("ticket").encode("ascii") == self.server.ticket
|
||||
):
|
||||
self.server.ticket = None
|
||||
print("good")
|
||||
return True
|
||||
else:
|
||||
print("bad (%r, %r) != (%r, %r)" % (
|
||||
self.params.get("service").encode("ascii"),
|
||||
self.params.get("ticket").encode("ascii"),
|
||||
self.server.service,
|
||||
self.server.ticket
|
||||
))
|
||||
|
||||
return False
|
||||
|
||||
def send_headers(self, code, content_type):
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", content_type)
|
||||
self.end_headers()
|
||||
|
||||
def do_GET(self):
|
||||
url = urlparse(self.path)
|
||||
self.params = dict(parse_qsl(url.query))
|
||||
if url.path == "/validate":
|
||||
self.send_headers(200, "text/plain; charset=utf-8")
|
||||
if self.test_params():
|
||||
self.wfile.write(b"yes\n" + self.server.username + b"\n")
|
||||
self.server.ticket = None
|
||||
else:
|
||||
self.wfile.write(b"no\n")
|
||||
elif url.path in {
|
||||
'/serviceValidate', '/serviceValidate',
|
||||
'/p3/serviceValidate', '/p3/proxyValidate'
|
||||
}:
|
||||
self.send_headers(200, "text/xml; charset=utf-8")
|
||||
if self.test_params():
|
||||
t = loader.get_template('cas_server/serviceValidate.xml')
|
||||
c = Context({
|
||||
'username': self.server.username,
|
||||
'attributes': self.server.attributes
|
||||
})
|
||||
self.wfile.write(return_bytes(t.render(c), "utf8"))
|
||||
else:
|
||||
t = loader.get_template('cas_server/serviceValidateError.xml')
|
||||
c = Context({
|
||||
'code': 'BAD_SERVICE_TICKET',
|
||||
'msg': 'Valids are (%r, %r)' % (self.server.service, self.server.ticket)
|
||||
})
|
||||
self.wfile.write(return_bytes(t.render(c), "utf8"))
|
||||
else:
|
||||
self.return_404()
|
||||
|
||||
def do_POST(self):
|
||||
url = urlparse(self.path)
|
||||
self.params = dict(parse_qsl(url.query))
|
||||
if url.path == "/samlValidate":
|
||||
self.send_headers(200, "text/xml; charset=utf-8")
|
||||
length = int(self.headers.get('content-length'))
|
||||
root = etree.fromstring(self.rfile.read(length))
|
||||
auth_req = root.getchildren()[1].getchildren()[0]
|
||||
ticket = auth_req.getchildren()[0].text.encode("ascii")
|
||||
if (
|
||||
self.server.ticket is not None and
|
||||
self.params.get("TARGET").encode("ascii") == self.server.service and
|
||||
ticket == self.server.ticket
|
||||
):
|
||||
self.server.ticket = None
|
||||
t = loader.get_template('cas_server/samlValidate.xml')
|
||||
c = Context({
|
||||
'IssueInstant': timezone.now().isoformat(),
|
||||
'expireInstant': (timezone.now() + timedelta(seconds=60)).isoformat(),
|
||||
'Recipient': self.server.service,
|
||||
'ResponseID': utils.gen_saml_id(),
|
||||
'username': self.server.username,
|
||||
'attributes': self.server.attributes,
|
||||
})
|
||||
self.wfile.write(return_bytes(t.render(c), "utf8"))
|
||||
else:
|
||||
t = loader.get_template('cas_server/samlValidateError.xml')
|
||||
c = Context({
|
||||
'IssueInstant': timezone.now().isoformat(),
|
||||
'ResponseID': utils.gen_saml_id(),
|
||||
'code': 'BAD_SERVICE_TICKET',
|
||||
'msg': 'Valids are (%r, %r)' % (self.server.service, self.server.ticket)
|
||||
})
|
||||
self.wfile.write(return_bytes(t.render(c), "utf8"))
|
||||
else:
|
||||
self.return_404()
|
||||
|
||||
def return_404(self):
|
||||
self.send_response(404)
|
||||
self.send_header(b"Content-type", "text/plain")
|
||||
self.end_headers()
|
||||
self.wfile.write("not found")
|
||||
|
||||
def log_message(self, *args):
|
||||
"""silent any log message"""
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def run(cls, service, ticket, username, attributes, port=0):
|
||||
"""Run a BaseHTTPServer using this class as handler"""
|
||||
server_class = BaseHTTPServer.HTTPServer
|
||||
httpd = server_class(("127.0.0.1", port), cls)
|
||||
httpd.service = service
|
||||
httpd.ticket = ticket
|
||||
httpd.username = username
|
||||
httpd.attributes = attributes
|
||||
(host, port) = httpd.socket.getsockname()
|
||||
|
||||
def lauch():
|
||||
"""routine to lauch in a background thread"""
|
||||
httpd.handle_request()
|
||||
httpd.server_close()
|
||||
|
||||
httpd_thread = Thread(target=lauch)
|
||||
httpd_thread.daemon = True
|
||||
httpd_thread.start()
|
||||
return (httpd, host, port)
|
||||
|
||||
|
||||
def logout_request(ticket):
|
||||
return u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
|
||||
ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
|
||||
<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
|
||||
<samlp:SessionIndex>%(ticket)s</samlp:SessionIndex>
|
||||
</samlp:LogoutRequest>""" % \
|
||||
{
|
||||
'id': utils.gen_saml_id(),
|
||||
'datetime': timezone.now().isoformat(),
|
||||
'ticket': ticket
|
||||
}
|
||||
|
Reference in New Issue
Block a user