Merge branch 'master' into federate

This commit is contained in:
Valentin Samir
2016-07-01 16:42:31 +02:00
30 changed files with 2788 additions and 1178 deletions

View File

193
cas_server/tests/mixin.py Normal file
View File

@ -0,0 +1,193 @@
# ⁻*- coding: utf-8 -*-
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
# more details.
#
# You should have received a copy of the GNU General Public License version 3
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# (c) 2016 Valentin Samir
"""Some mixin classes for tests"""
from cas_server.default_settings import settings
from django.utils import timezone
import re
from lxml import etree
from datetime import timedelta
from cas_server import models
from cas_server.tests.utils import get_auth_client
class BaseServicePattern(object):
"""Mixing for setting up service pattern for testing"""
def setup_service_patterns(self, proxy=False):
"""setting up service pattern"""
# For general purpose testing
self.service = "https://www.example.com"
self.service_pattern = models.ServicePattern.objects.create(
name="example",
pattern="^https://www\.example\.com(/.*)?$",
proxy=proxy,
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
# For testing the restrict_users attributes
self.service_restrict_user_fail = "https://restrict_user_fail.example.com"
self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create(
name="restrict_user_fail",
pattern="^https://restrict_user_fail\.example\.com(/.*)?$",
restrict_users=True,
proxy=proxy,
)
self.service_restrict_user_success = "https://restrict_user_success.example.com"
self.service_pattern_restrict_user_success = models.ServicePattern.objects.create(
name="restrict_user_success",
pattern="^https://restrict_user_success\.example\.com(/.*)?$",
restrict_users=True,
proxy=proxy,
)
models.Username.objects.create(
value=settings.CAS_TEST_USER,
service_pattern=self.service_pattern_restrict_user_success
)
# For testing the user attributes filtering conditions
self.service_filter_fail = "https://filter_fail.example.com"
self.service_pattern_filter_fail = models.ServicePattern.objects.create(
name="filter_fail",
pattern="^https://filter_fail\.example\.com(/.*)?$",
proxy=proxy,
)
models.FilterAttributValue.objects.create(
attribut="right",
pattern="^admin$",
service_pattern=self.service_pattern_filter_fail
)
self.service_filter_fail_alt = "https://filter_fail_alt.example.com"
self.service_pattern_filter_fail_alt = models.ServicePattern.objects.create(
name="filter_fail_alt",
pattern="^https://filter_fail_alt\.example\.com(/.*)?$",
proxy=proxy,
)
models.FilterAttributValue.objects.create(
attribut="nom",
pattern="^toto$",
service_pattern=self.service_pattern_filter_fail_alt
)
self.service_filter_success = "https://filter_success.example.com"
self.service_pattern_filter_success = models.ServicePattern.objects.create(
name="filter_success",
pattern="^https://filter_success\.example\.com(/.*)?$",
proxy=proxy,
)
models.FilterAttributValue.objects.create(
attribut="email",
pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']),
service_pattern=self.service_pattern_filter_success
)
# For testing the user_field attributes
self.service_field_needed_fail = "https://field_needed_fail.example.com"
self.service_pattern_field_needed_fail = models.ServicePattern.objects.create(
name="field_needed_fail",
pattern="^https://field_needed_fail\.example\.com(/.*)?$",
user_field="uid",
proxy=proxy,
)
self.service_field_needed_success = "https://field_needed_success.example.com"
self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
name="field_needed_success",
pattern="^https://field_needed_success\.example\.com(/.*)?$",
user_field="alias",
proxy=proxy,
)
self.service_field_needed_success_alt = "https://field_needed_success_alt.example.com"
self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
name="field_needed_success_alt",
pattern="^https://field_needed_success_alt\.example\.com(/.*)?$",
user_field="nom",
proxy=proxy,
)
class XmlContent(object):
"""Mixin for test on CAS XML responses"""
def assert_error(self, response, code, text=None):
"""Assert a validation error"""
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], code)
if text is not None:
self.assertEqual(error[0].text, text)
def assert_success(self, response, username, original_attributes):
"""assert a ticket validation success"""
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
sucess = root.xpath(
"//cas:authenticationSuccess",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertTrue(sucess)
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(users), 1)
self.assertEqual(users[0].text, username)
attributes = root.xpath(
"//cas:attributes",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(attributes), 1)
attrs1 = set()
for attr in attributes[0]:
attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(attributes), len(attrs1))
attrs2 = set()
for attr in attributes:
attrs2.add((attr.attrib['name'], attr.attrib['value']))
original = set()
for key, value in original_attributes.items():
if isinstance(value, list):
for sub_value in value:
original.add((key, sub_value))
else:
original.add((key, value))
self.assertEqual(attrs1, attrs2)
self.assertEqual(attrs1, original)
return root
class UserModels(object):
"""Mixin for test on CAS user models"""
@staticmethod
def expire_user():
"""return an expired user"""
client = get_auth_client()
new_date = timezone.now() - timedelta(seconds=(settings.SESSION_COOKIE_AGE + 600))
models.User.objects.filter(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
).update(date=new_date)
return client
@staticmethod
def get_user(client):
"""return the user associated with an authenticated client"""
return models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)

View File

@ -0,0 +1,84 @@
"""
Django test settings for cas_server application.
Generated by 'django-admin startproject' using Django 1.9.7.
For more information on this file, see
https://docs.djangoproject.com/en/1.9/topics/settings/
For the full list of settings and their values, see
https://docs.djangoproject.com/en/1.9/ref/settings/
"""
import os
# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/1.9/howto/deployment/checklist/
# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = 'changeme'
# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = True
ALLOWED_HOSTS = []
# Application definition
INSTALLED_APPS = [
'django.contrib.admin',
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
'django.contrib.messages',
'django.contrib.staticfiles',
'bootstrap3',
'cas_server',
]
MIDDLEWARE_CLASSES = [
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.auth.middleware.SessionAuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
'django.middleware.locale.LocaleMiddleware',
]
ROOT_URLCONF = 'cas_server.tests.urls'
# Database
# https://docs.djangoproject.com/en/1.9/ref/settings/#databases
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.sqlite3',
'NAME': ':memory:',
}
}
# Internationalization
# https://docs.djangoproject.com/en/1.9/topics/i18n/
LANGUAGE_CODE = 'en-us'
TIME_ZONE = 'UTC'
USE_I18N = True
USE_L10N = True
USE_TZ = True
# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/1.9/howto/static-files/
STATIC_URL = '/static/'

View File

@ -0,0 +1,166 @@
# ⁻*- coding: utf-8 -*-
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
# more details.
#
# You should have received a copy of the GNU General Public License version 3
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# (c) 2016 Valentin Samir
"""Tests module for models"""
from cas_server.default_settings import settings
from django.test import TestCase
from django.test.utils import override_settings
from django.utils import timezone
from datetime import timedelta
from importlib import import_module
from cas_server import models
from cas_server.tests.utils import get_auth_client, HttpParamsHandler
from cas_server.tests.mixin import UserModels, BaseServicePattern
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
class UserTestCase(TestCase, UserModels):
"""tests for the user models"""
def setUp(self):
"""Prepare the test context"""
self.service = 'http://127.0.0.1:45678'
self.service_pattern = models.ServicePattern.objects.create(
name="localhost",
pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
single_log_out=True
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_clean_old_entries(self):
"""test clean_old_entries"""
# get an authenticated client
client = self.expire_user()
# assert the user exists before being cleaned
self.assertEqual(len(models.User.objects.all()), 1)
# assert the last activity date is before the expiry date
self.assertTrue(
self.get_user(client).date < (
timezone.now() - timedelta(seconds=settings.SESSION_COOKIE_AGE)
)
)
# delete old inactive users
models.User.clean_old_entries()
# assert the user has being well delete
self.assertEqual(len(models.User.objects.all()), 0)
def test_clean_deleted_sessions(self):
"""test clean_deleted_sessions"""
# get an authenticated client
client1 = get_auth_client()
client2 = get_auth_client()
# generate a ticket to fire SLO during user cleaning (SLO should fail a nothing listen
# on self.service)
ticket = self.get_user(client1).get_ticket(
models.ServiceTicket,
self.service,
self.service_pattern,
renew=False
)
ticket.validate = True
ticket.save()
# simulated expired session being garbage collected for client1
session = SessionStore(session_key=client1.session.session_key)
session.flush()
# assert the user exists before being cleaned
self.assertTrue(self.get_user(client1))
self.assertTrue(self.get_user(client2))
self.assertEqual(len(models.User.objects.all()), 2)
# session has being remove so the user of client1 is no longer authenticated
self.assertFalse(client1.session.get("authenticated"))
# the user a client2 should still be authenticated
self.assertTrue(client2.session.get("authenticated"))
# the user should be deleted
models.User.clean_deleted_sessions()
# assert the user with expired sessions has being well deleted but the other remain
self.assertEqual(len(models.User.objects.all()), 1)
self.assertFalse(models.ServiceTicket.objects.all())
self.assertTrue(client2.session.get("authenticated"))
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
class TicketTestCase(TestCase, UserModels, BaseServicePattern):
"""tests for the tickets models"""
def setUp(self):
"""Prepare the test context"""
self.setup_service_patterns()
self.service = 'http://127.0.0.1:45678'
self.service_pattern = models.ServicePattern.objects.create(
name="localhost",
pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
single_log_out=True
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
@staticmethod
def get_ticket(
user,
ticket_class,
service,
service_pattern,
renew=False,
validate=False,
validity_expired=False,
timeout_expired=False,
single_log_out=False,
):
"""Return a ticket"""
ticket = user.get_ticket(ticket_class, service, service_pattern, renew)
ticket.validate = validate
ticket.single_log_out = single_log_out
if validity_expired:
ticket.creation = min(
ticket.creation,
(timezone.now() - timedelta(seconds=(ticket_class.VALIDITY + 10)))
)
if timeout_expired:
ticket.creation = min(
ticket.creation,
(timezone.now() - timedelta(seconds=(ticket_class.TIMEOUT + 10)))
)
ticket.save()
return ticket
def test_clean_old_service_ticket(self):
"""test tickets clean_old_entries"""
# ge an authenticated client
client = get_auth_client()
# get the user associated to the client
user = self.get_user(client)
# generate a ticket for that client, waiting for validation
self.get_ticket(user, models.ServiceTicket, self.service, self.service_pattern)
# generate another ticket for those validation time has expired
self.get_ticket(
user, models.ServiceTicket,
self.service, self.service_pattern, validity_expired=True
)
(httpd, host, port) = HttpParamsHandler.run()[0:3]
service = "http://%s:%s" % (host, port)
# generate a ticket with SLO having timeout reach
self.get_ticket(
user, models.ServiceTicket,
service, self.service_pattern, timeout_expired=True,
validate=True, single_log_out=True
)
# there should be 3 tickets in the db
self.assertEqual(len(models.ServiceTicket.objects.all()), 3)
# we call the clean_old_entries method that should delete validated non SLO ticket and
# expired non validated ticket and send SLO for SLO expired ticket before deleting then
models.ServiceTicket.clean_old_entries()
params = httpd.PARAMS
# we successfully got a SLO request
self.assertTrue(b'logoutRequest' in params and params[b'logoutRequest'])
# only 1 ticket remain in the db
self.assertEqual(len(models.ServiceTicket.objects.all()), 1)

View File

@ -0,0 +1,191 @@
# ⁻*- coding: utf-8 -*-
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
# more details.
#
# You should have received a copy of the GNU General Public License version 3
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# (c) 2016 Valentin Samir
"""Tests module for utils"""
from django.test import TestCase
import six
from cas_server import utils
class CheckPasswordCase(TestCase):
"""Tests for the utils function `utils.check_password`"""
def setUp(self):
"""Generate random bytes string that will be used ass passwords"""
self.password1 = utils.gen_saml_id()
self.password2 = utils.gen_saml_id()
if not isinstance(self.password1, bytes): # pragma: no cover executed only in python3
self.password1 = self.password1.encode("utf8")
self.password2 = self.password2.encode("utf8")
def test_setup(self):
"""check that generated password are bytes"""
self.assertIsInstance(self.password1, bytes)
self.assertIsInstance(self.password2, bytes)
def test_plain(self):
"""test the plain auth method"""
self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8"))
self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8"))
def test_plain_unicode(self):
"""test the plain auth method with unicode input"""
self.assertTrue(
utils.check_password(
"plain",
self.password1.decode("utf8"),
self.password1.decode("utf8"),
"utf8"
)
)
self.assertFalse(
utils.check_password(
"plain",
self.password1.decode("utf8"),
self.password2.decode("utf8"),
"utf8"
)
)
def test_crypt(self):
"""test the crypt auth method"""
salts = ["$6$UVVAQvrMyXMF3FF3", "aa"]
hashed_password1 = []
for salt in salts:
if six.PY3:
hashed_password1.append(
utils.crypt.crypt(
self.password1.decode("utf8"),
salt
).encode("utf8")
)
else:
hashed_password1.append(utils.crypt.crypt(self.password1, salt))
for hp1 in hashed_password1:
self.assertTrue(utils.check_password("crypt", self.password1, hp1, "utf8"))
self.assertFalse(utils.check_password("crypt", self.password2, hp1, "utf8"))
with self.assertRaises(ValueError):
utils.check_password("crypt", self.password1, b"$truc$s$dsdsd", "utf8")
def test_ldap_password_valid(self):
"""test the ldap auth method with all the schemes"""
salt = b"UVVAQvrMyXMF3FF3"
schemes_salt = [b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}"]
schemes_nosalt = [b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"]
hashed_password1 = []
for scheme in schemes_salt:
hashed_password1.append(
utils.LdapHashUserPassword.hash(scheme, self.password1, salt, charset="utf8")
)
for scheme in schemes_nosalt:
hashed_password1.append(
utils.LdapHashUserPassword.hash(scheme, self.password1, charset="utf8")
)
hashed_password1.append(
utils.LdapHashUserPassword.hash(
b"{CRYPT}",
self.password1,
b"$6$UVVAQvrMyXMF3FF3",
charset="utf8"
)
)
for hp1 in hashed_password1:
self.assertIsInstance(hp1, bytes)
self.assertTrue(utils.check_password("ldap", self.password1, hp1, "utf8"))
self.assertFalse(utils.check_password("ldap", self.password2, hp1, "utf8"))
def test_ldap_password_fail(self):
"""test the ldap auth method with malformed hash or bad schemes"""
salt = b"UVVAQvrMyXMF3FF3"
schemes_salt = [b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}"]
schemes_nosalt = [b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"]
# first try to hash with bad parameters
with self.assertRaises(utils.LdapHashUserPassword.BadScheme):
utils.LdapHashUserPassword.hash(b"TOTO", self.password1)
for scheme in schemes_nosalt:
with self.assertRaises(utils.LdapHashUserPassword.BadScheme):
utils.LdapHashUserPassword.hash(scheme, self.password1, salt)
for scheme in schemes_salt:
with self.assertRaises(utils.LdapHashUserPassword.BadScheme):
utils.LdapHashUserPassword.hash(scheme, self.password1)
with self.assertRaises(utils.LdapHashUserPassword.BadSalt):
utils.LdapHashUserPassword.hash(b'{CRYPT}', self.password1, b"$truc$toto")
# then try to check hash with bad hashes
with self.assertRaises(utils.LdapHashUserPassword.BadHash):
utils.check_password("ldap", self.password1, b"TOTOssdsdsd", "utf8")
for scheme in schemes_salt:
with self.assertRaises(utils.LdapHashUserPassword.BadHash):
utils.check_password("ldap", self.password1, scheme + b"dG90b3E8ZHNkcw==", "utf8")
def test_hex(self):
"""test all the hex_HASH method: the hashed password is a simple hash of the password"""
hashes = ["md5", "sha1", "sha224", "sha256", "sha384", "sha512"]
hashed_password1 = []
for hash in hashes:
hashed_password1.append(
("hex_%s" % hash, getattr(utils.hashlib, hash)(self.password1).hexdigest())
)
for (method, hp1) in hashed_password1:
self.assertTrue(utils.check_password(method, self.password1, hp1, "utf8"))
self.assertFalse(utils.check_password(method, self.password2, hp1, "utf8"))
def test_bad_method(self):
"""try to check password with a bad method, should raise a ValueError"""
with self.assertRaises(ValueError):
utils.check_password("test", self.password1, b"$truc$s$dsdsd", "utf8")
class UtilsTestCase(TestCase):
"""tests for some little utils functions"""
def test_import_attr(self):
"""
test the import_attr function. Feeded with a dotted path string, it should
import the dotted module and return that last componend of the dotted path
(function, class or variable)
"""
with self.assertRaises(ImportError):
utils.import_attr('toto.titi.tutu')
with self.assertRaises(AttributeError):
utils.import_attr('cas_server.utils.toto')
with self.assertRaises(ValueError):
utils.import_attr('toto')
self.assertEqual(
utils.import_attr('cas_server.default_app_config'),
'cas_server.apps.CasAppConfig'
)
self.assertEqual(utils.import_attr(utils), utils)
def test_update_url(self):
"""
test the update_url function. Given an url with possible GET parameter and a dict
the function build a url with GET parameters updated by the dictionnary
"""
url1 = utils.update_url(u"https://www.example.com?toto=1", {u"tata": u"2"})
url2 = utils.update_url(b"https://www.example.com?toto=1", {b"tata": b"2"})
self.assertEqual(url1, u"https://www.example.com?tata=2&toto=1")
self.assertEqual(url2, u"https://www.example.com?tata=2&toto=1")
url3 = utils.update_url(u"https://www.example.com?toto=1", {u"toto": u"2"})
self.assertEqual(url3, u"https://www.example.com?toto=2")
def test_crypt_salt_is_valid(self):
"""test the function crypt_salt_is_valid who test if a crypt salt is valid"""
self.assertFalse(utils.crypt_salt_is_valid("")) # len 0
self.assertFalse(utils.crypt_salt_is_valid("a")) # len 1
self.assertFalse(utils.crypt_salt_is_valid("$$")) # start with $ followed by $
self.assertFalse(utils.crypt_salt_is_valid("$toto")) # start with $ but no secondary $
self.assertFalse(utils.crypt_salt_is_valid("$toto$toto")) # algorithm toto not known

File diff suppressed because it is too large Load Diff

22
cas_server/tests/urls.py Normal file
View File

@ -0,0 +1,22 @@
"""cas URL Configuration
The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/1.9/topics/http/urls/
Examples:
Function views
1. Add an import: from my_app import views
2. Add a URL to urlpatterns: url(r'^$', views.home, name='home')
Class-based views
1. Add an import: from other_app.views import Home
2. Add a URL to urlpatterns: url(r'^$', Home.as_view(), name='home')
Including another URLconf
1. Import the include() function: from django.conf.urls import url, include, include
2. Add a URL to urlpatterns: url(r'^blog/', include('blog.urls'))
"""
from django.conf.urls import url, include
from django.contrib import admin
urlpatterns = [
url(r'^admin/', admin.site.urls),
url(r'^', include('cas_server.urls', namespace='cas_server')),
]

180
cas_server/tests/utils.py Normal file
View File

@ -0,0 +1,180 @@
# ⁻*- coding: utf-8 -*-
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
# more details.
#
# You should have received a copy of the GNU General Public License version 3
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# (c) 2016 Valentin Samir
"""Some utils functions for tests"""
from cas_server.default_settings import settings
from django.test import Client
import cgi
from threading import Thread
from lxml import etree
from six.moves import BaseHTTPServer
from six.moves.urllib.parse import urlparse, parse_qsl
from cas_server import models
def copy_form(form):
"""Copy form value into a dict"""
params = {}
for field in form:
if field.value():
params[field.name] = field.value()
else:
params[field.name] = ""
return params
def get_login_page_params(client=None):
"""Return a client and the POST params for the client to login"""
if client is None:
client = Client()
response = client.get('/login')
params = copy_form(response.context["form"])
return client, params
def get_auth_client(**update):
"""return a authenticated client"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
params.update(update)
client.post('/login', params)
assert client.session.get("authenticated")
return client
def get_user_ticket_request(service):
"""Make an auth client to request a ticket for `service`, return the tuple (user, ticket)"""
client = get_auth_client()
response = client.get("/login", {"service": service})
ticket_value = response['Location'].split('ticket=')[-1]
user = models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)
ticket = models.ServiceTicket.objects.get(value=ticket_value)
return (user, ticket, client)
def get_validated_ticket(service):
"""Return a tick that has being already validated. Used to test SLO"""
(ticket, auth_client) = get_user_ticket_request(service)[1:3]
client = Client()
response = client.get('/validate', {'ticket': ticket.value, 'service': service})
assert (response.status_code == 200)
assert (response.content == b'yes\ntest\n')
ticket = models.ServiceTicket.objects.get(value=ticket.value)
return (auth_client, ticket)
def get_pgt():
"""return a dict contening a service, user and PGT ticket for this service"""
(httpd, host, port) = HttpParamsHandler.run()[0:3]
service = "http://%s:%s" % (host, port)
(user, ticket) = get_user_ticket_request(service)[:2]
client = Client()
client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service})
params = httpd.PARAMS
params["service"] = service
params["user"] = user
return params
def get_proxy_ticket(service):
"""Return a ProxyTicket waiting for validation"""
params = get_pgt()
# get a proxy ticket
client = Client()
response = client.get('/proxy', {'pgt': params['pgtId'], 'targetService': service})
root = etree.fromstring(response.content)
proxy_ticket = root.xpath(
"//cas:proxyTicket",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
proxy_ticket = proxy_ticket[0].text
ticket = models.ProxyTicket.objects.get(value=proxy_ticket)
return ticket
class HttpParamsHandler(BaseHTTPServer.BaseHTTPRequestHandler):
"""
A simple http server that return 200 on GET or POST
and store GET or POST parameters. Used in unit tests
"""
def do_GET(self):
"""Called on a GET request on the BaseHTTPServer"""
self.send_response(200)
self.send_header(b"Content-type", "text/plain")
self.end_headers()
self.wfile.write(b"ok")
url = urlparse(self.path)
params = dict(parse_qsl(url.query))
self.server.PARAMS = params
def do_POST(self):
"""Called on a POST request on the BaseHTTPServer"""
ctype, pdict = cgi.parse_header(self.headers.get('content-type'))
if ctype == 'multipart/form-data':
postvars = cgi.parse_multipart(self.rfile, pdict)
elif ctype == 'application/x-www-form-urlencoded':
length = int(self.headers.get('content-length'))
postvars = cgi.parse_qs(self.rfile.read(length), keep_blank_values=1)
else:
postvars = {}
self.server.PARAMS = postvars
def log_message(self, *args):
"""silent any log message"""
return
@classmethod
def run(cls):
"""Run a BaseHTTPServer using this class as handler"""
server_class = BaseHTTPServer.HTTPServer
httpd = server_class(("127.0.0.1", 0), cls)
(host, port) = httpd.socket.getsockname()
def lauch():
"""routine to lauch in a background thread"""
httpd.handle_request()
httpd.server_close()
httpd_thread = Thread(target=lauch)
httpd_thread.daemon = True
httpd_thread.start()
return (httpd, host, port)
class Http404Handler(HttpParamsHandler):
"""A simple http server that always return 404 not found. Used in unit tests"""
def do_GET(self):
"""Called on a GET request on the BaseHTTPServer"""
self.send_response(404)
self.send_header(b"Content-type", "text/plain")
self.end_headers()
self.wfile.write(b"error 404 not found")
def do_POST(self):
"""Called on a POST request on the BaseHTTPServer"""
return self.do_GET()