Merge branch 'master' into federate
This commit is contained in:
0
cas_server/tests/__init__.py
Normal file
0
cas_server/tests/__init__.py
Normal file
193
cas_server/tests/mixin.py
Normal file
193
cas_server/tests/mixin.py
Normal file
@ -0,0 +1,193 @@
|
||||
# ⁻*- coding: utf-8 -*-
|
||||
# This program is distributed in the hope that it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
|
||||
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
|
||||
# more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License version 3
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
||||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||
#
|
||||
# (c) 2016 Valentin Samir
|
||||
"""Some mixin classes for tests"""
|
||||
from cas_server.default_settings import settings
|
||||
from django.utils import timezone
|
||||
|
||||
import re
|
||||
from lxml import etree
|
||||
from datetime import timedelta
|
||||
|
||||
from cas_server import models
|
||||
from cas_server.tests.utils import get_auth_client
|
||||
|
||||
|
||||
class BaseServicePattern(object):
|
||||
"""Mixing for setting up service pattern for testing"""
|
||||
def setup_service_patterns(self, proxy=False):
|
||||
"""setting up service pattern"""
|
||||
# For general purpose testing
|
||||
self.service = "https://www.example.com"
|
||||
self.service_pattern = models.ServicePattern.objects.create(
|
||||
name="example",
|
||||
pattern="^https://www\.example\.com(/.*)?$",
|
||||
proxy=proxy,
|
||||
)
|
||||
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
|
||||
|
||||
# For testing the restrict_users attributes
|
||||
self.service_restrict_user_fail = "https://restrict_user_fail.example.com"
|
||||
self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create(
|
||||
name="restrict_user_fail",
|
||||
pattern="^https://restrict_user_fail\.example\.com(/.*)?$",
|
||||
restrict_users=True,
|
||||
proxy=proxy,
|
||||
)
|
||||
self.service_restrict_user_success = "https://restrict_user_success.example.com"
|
||||
self.service_pattern_restrict_user_success = models.ServicePattern.objects.create(
|
||||
name="restrict_user_success",
|
||||
pattern="^https://restrict_user_success\.example\.com(/.*)?$",
|
||||
restrict_users=True,
|
||||
proxy=proxy,
|
||||
)
|
||||
models.Username.objects.create(
|
||||
value=settings.CAS_TEST_USER,
|
||||
service_pattern=self.service_pattern_restrict_user_success
|
||||
)
|
||||
|
||||
# For testing the user attributes filtering conditions
|
||||
self.service_filter_fail = "https://filter_fail.example.com"
|
||||
self.service_pattern_filter_fail = models.ServicePattern.objects.create(
|
||||
name="filter_fail",
|
||||
pattern="^https://filter_fail\.example\.com(/.*)?$",
|
||||
proxy=proxy,
|
||||
)
|
||||
models.FilterAttributValue.objects.create(
|
||||
attribut="right",
|
||||
pattern="^admin$",
|
||||
service_pattern=self.service_pattern_filter_fail
|
||||
)
|
||||
self.service_filter_fail_alt = "https://filter_fail_alt.example.com"
|
||||
self.service_pattern_filter_fail_alt = models.ServicePattern.objects.create(
|
||||
name="filter_fail_alt",
|
||||
pattern="^https://filter_fail_alt\.example\.com(/.*)?$",
|
||||
proxy=proxy,
|
||||
)
|
||||
models.FilterAttributValue.objects.create(
|
||||
attribut="nom",
|
||||
pattern="^toto$",
|
||||
service_pattern=self.service_pattern_filter_fail_alt
|
||||
)
|
||||
self.service_filter_success = "https://filter_success.example.com"
|
||||
self.service_pattern_filter_success = models.ServicePattern.objects.create(
|
||||
name="filter_success",
|
||||
pattern="^https://filter_success\.example\.com(/.*)?$",
|
||||
proxy=proxy,
|
||||
)
|
||||
models.FilterAttributValue.objects.create(
|
||||
attribut="email",
|
||||
pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']),
|
||||
service_pattern=self.service_pattern_filter_success
|
||||
)
|
||||
|
||||
# For testing the user_field attributes
|
||||
self.service_field_needed_fail = "https://field_needed_fail.example.com"
|
||||
self.service_pattern_field_needed_fail = models.ServicePattern.objects.create(
|
||||
name="field_needed_fail",
|
||||
pattern="^https://field_needed_fail\.example\.com(/.*)?$",
|
||||
user_field="uid",
|
||||
proxy=proxy,
|
||||
)
|
||||
self.service_field_needed_success = "https://field_needed_success.example.com"
|
||||
self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
|
||||
name="field_needed_success",
|
||||
pattern="^https://field_needed_success\.example\.com(/.*)?$",
|
||||
user_field="alias",
|
||||
proxy=proxy,
|
||||
)
|
||||
self.service_field_needed_success_alt = "https://field_needed_success_alt.example.com"
|
||||
self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
|
||||
name="field_needed_success_alt",
|
||||
pattern="^https://field_needed_success_alt\.example\.com(/.*)?$",
|
||||
user_field="nom",
|
||||
proxy=proxy,
|
||||
)
|
||||
|
||||
|
||||
class XmlContent(object):
|
||||
"""Mixin for test on CAS XML responses"""
|
||||
def assert_error(self, response, code, text=None):
|
||||
"""Assert a validation error"""
|
||||
self.assertEqual(response.status_code, 200)
|
||||
root = etree.fromstring(response.content)
|
||||
error = root.xpath(
|
||||
"//cas:authenticationFailure",
|
||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
||||
)
|
||||
self.assertEqual(len(error), 1)
|
||||
self.assertEqual(error[0].attrib['code'], code)
|
||||
if text is not None:
|
||||
self.assertEqual(error[0].text, text)
|
||||
|
||||
def assert_success(self, response, username, original_attributes):
|
||||
"""assert a ticket validation success"""
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
root = etree.fromstring(response.content)
|
||||
sucess = root.xpath(
|
||||
"//cas:authenticationSuccess",
|
||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
||||
)
|
||||
self.assertTrue(sucess)
|
||||
|
||||
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
||||
self.assertEqual(len(users), 1)
|
||||
self.assertEqual(users[0].text, username)
|
||||
|
||||
attributes = root.xpath(
|
||||
"//cas:attributes",
|
||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
||||
)
|
||||
self.assertEqual(len(attributes), 1)
|
||||
attrs1 = set()
|
||||
for attr in attributes[0]:
|
||||
attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
|
||||
|
||||
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
||||
self.assertEqual(len(attributes), len(attrs1))
|
||||
attrs2 = set()
|
||||
for attr in attributes:
|
||||
attrs2.add((attr.attrib['name'], attr.attrib['value']))
|
||||
original = set()
|
||||
for key, value in original_attributes.items():
|
||||
if isinstance(value, list):
|
||||
for sub_value in value:
|
||||
original.add((key, sub_value))
|
||||
else:
|
||||
original.add((key, value))
|
||||
self.assertEqual(attrs1, attrs2)
|
||||
self.assertEqual(attrs1, original)
|
||||
|
||||
return root
|
||||
|
||||
|
||||
class UserModels(object):
|
||||
"""Mixin for test on CAS user models"""
|
||||
@staticmethod
|
||||
def expire_user():
|
||||
"""return an expired user"""
|
||||
client = get_auth_client()
|
||||
|
||||
new_date = timezone.now() - timedelta(seconds=(settings.SESSION_COOKIE_AGE + 600))
|
||||
models.User.objects.filter(
|
||||
username=settings.CAS_TEST_USER,
|
||||
session_key=client.session.session_key
|
||||
).update(date=new_date)
|
||||
return client
|
||||
|
||||
@staticmethod
|
||||
def get_user(client):
|
||||
"""return the user associated with an authenticated client"""
|
||||
return models.User.objects.get(
|
||||
username=settings.CAS_TEST_USER,
|
||||
session_key=client.session.session_key
|
||||
)
|
84
cas_server/tests/settings.py
Normal file
84
cas_server/tests/settings.py
Normal file
@ -0,0 +1,84 @@
|
||||
"""
|
||||
Django test settings for cas_server application.
|
||||
|
||||
Generated by 'django-admin startproject' using Django 1.9.7.
|
||||
|
||||
For more information on this file, see
|
||||
https://docs.djangoproject.com/en/1.9/topics/settings/
|
||||
|
||||
For the full list of settings and their values, see
|
||||
https://docs.djangoproject.com/en/1.9/ref/settings/
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
|
||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
# Quick-start development settings - unsuitable for production
|
||||
# See https://docs.djangoproject.com/en/1.9/howto/deployment/checklist/
|
||||
|
||||
# SECURITY WARNING: keep the secret key used in production secret!
|
||||
SECRET_KEY = 'changeme'
|
||||
|
||||
# SECURITY WARNING: don't run with debug turned on in production!
|
||||
DEBUG = True
|
||||
|
||||
ALLOWED_HOSTS = []
|
||||
|
||||
|
||||
# Application definition
|
||||
|
||||
INSTALLED_APPS = [
|
||||
'django.contrib.admin',
|
||||
'django.contrib.auth',
|
||||
'django.contrib.contenttypes',
|
||||
'django.contrib.sessions',
|
||||
'django.contrib.messages',
|
||||
'django.contrib.staticfiles',
|
||||
'bootstrap3',
|
||||
'cas_server',
|
||||
]
|
||||
|
||||
MIDDLEWARE_CLASSES = [
|
||||
'django.contrib.sessions.middleware.SessionMiddleware',
|
||||
'django.middleware.common.CommonMiddleware',
|
||||
'django.middleware.csrf.CsrfViewMiddleware',
|
||||
'django.contrib.auth.middleware.AuthenticationMiddleware',
|
||||
'django.contrib.auth.middleware.SessionAuthenticationMiddleware',
|
||||
'django.contrib.messages.middleware.MessageMiddleware',
|
||||
'django.middleware.clickjacking.XFrameOptionsMiddleware',
|
||||
'django.middleware.locale.LocaleMiddleware',
|
||||
]
|
||||
|
||||
ROOT_URLCONF = 'cas_server.tests.urls'
|
||||
|
||||
# Database
|
||||
# https://docs.djangoproject.com/en/1.9/ref/settings/#databases
|
||||
|
||||
DATABASES = {
|
||||
'default': {
|
||||
'ENGINE': 'django.db.backends.sqlite3',
|
||||
'NAME': ':memory:',
|
||||
}
|
||||
}
|
||||
|
||||
# Internationalization
|
||||
# https://docs.djangoproject.com/en/1.9/topics/i18n/
|
||||
|
||||
LANGUAGE_CODE = 'en-us'
|
||||
|
||||
TIME_ZONE = 'UTC'
|
||||
|
||||
USE_I18N = True
|
||||
|
||||
USE_L10N = True
|
||||
|
||||
USE_TZ = True
|
||||
|
||||
|
||||
# Static files (CSS, JavaScript, Images)
|
||||
# https://docs.djangoproject.com/en/1.9/howto/static-files/
|
||||
|
||||
STATIC_URL = '/static/'
|
166
cas_server/tests/test_models.py
Normal file
166
cas_server/tests/test_models.py
Normal file
@ -0,0 +1,166 @@
|
||||
# ⁻*- coding: utf-8 -*-
|
||||
# This program is distributed in the hope that it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
|
||||
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
|
||||
# more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License version 3
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
||||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||
#
|
||||
# (c) 2016 Valentin Samir
|
||||
"""Tests module for models"""
|
||||
from cas_server.default_settings import settings
|
||||
|
||||
from django.test import TestCase
|
||||
from django.test.utils import override_settings
|
||||
from django.utils import timezone
|
||||
|
||||
from datetime import timedelta
|
||||
from importlib import import_module
|
||||
|
||||
from cas_server import models
|
||||
from cas_server.tests.utils import get_auth_client, HttpParamsHandler
|
||||
from cas_server.tests.mixin import UserModels, BaseServicePattern
|
||||
|
||||
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
|
||||
|
||||
|
||||
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
|
||||
class UserTestCase(TestCase, UserModels):
|
||||
"""tests for the user models"""
|
||||
def setUp(self):
|
||||
"""Prepare the test context"""
|
||||
self.service = 'http://127.0.0.1:45678'
|
||||
self.service_pattern = models.ServicePattern.objects.create(
|
||||
name="localhost",
|
||||
pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
|
||||
single_log_out=True
|
||||
)
|
||||
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
|
||||
|
||||
def test_clean_old_entries(self):
|
||||
"""test clean_old_entries"""
|
||||
# get an authenticated client
|
||||
client = self.expire_user()
|
||||
# assert the user exists before being cleaned
|
||||
self.assertEqual(len(models.User.objects.all()), 1)
|
||||
# assert the last activity date is before the expiry date
|
||||
self.assertTrue(
|
||||
self.get_user(client).date < (
|
||||
timezone.now() - timedelta(seconds=settings.SESSION_COOKIE_AGE)
|
||||
)
|
||||
)
|
||||
# delete old inactive users
|
||||
models.User.clean_old_entries()
|
||||
# assert the user has being well delete
|
||||
self.assertEqual(len(models.User.objects.all()), 0)
|
||||
|
||||
def test_clean_deleted_sessions(self):
|
||||
"""test clean_deleted_sessions"""
|
||||
# get an authenticated client
|
||||
client1 = get_auth_client()
|
||||
client2 = get_auth_client()
|
||||
# generate a ticket to fire SLO during user cleaning (SLO should fail a nothing listen
|
||||
# on self.service)
|
||||
ticket = self.get_user(client1).get_ticket(
|
||||
models.ServiceTicket,
|
||||
self.service,
|
||||
self.service_pattern,
|
||||
renew=False
|
||||
)
|
||||
ticket.validate = True
|
||||
ticket.save()
|
||||
# simulated expired session being garbage collected for client1
|
||||
session = SessionStore(session_key=client1.session.session_key)
|
||||
session.flush()
|
||||
# assert the user exists before being cleaned
|
||||
self.assertTrue(self.get_user(client1))
|
||||
self.assertTrue(self.get_user(client2))
|
||||
self.assertEqual(len(models.User.objects.all()), 2)
|
||||
# session has being remove so the user of client1 is no longer authenticated
|
||||
self.assertFalse(client1.session.get("authenticated"))
|
||||
# the user a client2 should still be authenticated
|
||||
self.assertTrue(client2.session.get("authenticated"))
|
||||
# the user should be deleted
|
||||
models.User.clean_deleted_sessions()
|
||||
# assert the user with expired sessions has being well deleted but the other remain
|
||||
self.assertEqual(len(models.User.objects.all()), 1)
|
||||
self.assertFalse(models.ServiceTicket.objects.all())
|
||||
self.assertTrue(client2.session.get("authenticated"))
|
||||
|
||||
|
||||
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
|
||||
class TicketTestCase(TestCase, UserModels, BaseServicePattern):
|
||||
"""tests for the tickets models"""
|
||||
def setUp(self):
|
||||
"""Prepare the test context"""
|
||||
self.setup_service_patterns()
|
||||
self.service = 'http://127.0.0.1:45678'
|
||||
self.service_pattern = models.ServicePattern.objects.create(
|
||||
name="localhost",
|
||||
pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
|
||||
single_log_out=True
|
||||
)
|
||||
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
|
||||
|
||||
@staticmethod
|
||||
def get_ticket(
|
||||
user,
|
||||
ticket_class,
|
||||
service,
|
||||
service_pattern,
|
||||
renew=False,
|
||||
validate=False,
|
||||
validity_expired=False,
|
||||
timeout_expired=False,
|
||||
single_log_out=False,
|
||||
):
|
||||
"""Return a ticket"""
|
||||
ticket = user.get_ticket(ticket_class, service, service_pattern, renew)
|
||||
ticket.validate = validate
|
||||
ticket.single_log_out = single_log_out
|
||||
if validity_expired:
|
||||
ticket.creation = min(
|
||||
ticket.creation,
|
||||
(timezone.now() - timedelta(seconds=(ticket_class.VALIDITY + 10)))
|
||||
)
|
||||
if timeout_expired:
|
||||
ticket.creation = min(
|
||||
ticket.creation,
|
||||
(timezone.now() - timedelta(seconds=(ticket_class.TIMEOUT + 10)))
|
||||
)
|
||||
ticket.save()
|
||||
return ticket
|
||||
|
||||
def test_clean_old_service_ticket(self):
|
||||
"""test tickets clean_old_entries"""
|
||||
# ge an authenticated client
|
||||
client = get_auth_client()
|
||||
# get the user associated to the client
|
||||
user = self.get_user(client)
|
||||
# generate a ticket for that client, waiting for validation
|
||||
self.get_ticket(user, models.ServiceTicket, self.service, self.service_pattern)
|
||||
# generate another ticket for those validation time has expired
|
||||
self.get_ticket(
|
||||
user, models.ServiceTicket,
|
||||
self.service, self.service_pattern, validity_expired=True
|
||||
)
|
||||
(httpd, host, port) = HttpParamsHandler.run()[0:3]
|
||||
service = "http://%s:%s" % (host, port)
|
||||
# generate a ticket with SLO having timeout reach
|
||||
self.get_ticket(
|
||||
user, models.ServiceTicket,
|
||||
service, self.service_pattern, timeout_expired=True,
|
||||
validate=True, single_log_out=True
|
||||
)
|
||||
# there should be 3 tickets in the db
|
||||
self.assertEqual(len(models.ServiceTicket.objects.all()), 3)
|
||||
# we call the clean_old_entries method that should delete validated non SLO ticket and
|
||||
# expired non validated ticket and send SLO for SLO expired ticket before deleting then
|
||||
models.ServiceTicket.clean_old_entries()
|
||||
params = httpd.PARAMS
|
||||
# we successfully got a SLO request
|
||||
self.assertTrue(b'logoutRequest' in params and params[b'logoutRequest'])
|
||||
# only 1 ticket remain in the db
|
||||
self.assertEqual(len(models.ServiceTicket.objects.all()), 1)
|
191
cas_server/tests/test_utils.py
Normal file
191
cas_server/tests/test_utils.py
Normal file
@ -0,0 +1,191 @@
|
||||
# ⁻*- coding: utf-8 -*-
|
||||
# This program is distributed in the hope that it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
|
||||
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
|
||||
# more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License version 3
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
||||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||
#
|
||||
# (c) 2016 Valentin Samir
|
||||
"""Tests module for utils"""
|
||||
from django.test import TestCase
|
||||
|
||||
import six
|
||||
|
||||
from cas_server import utils
|
||||
|
||||
|
||||
class CheckPasswordCase(TestCase):
|
||||
"""Tests for the utils function `utils.check_password`"""
|
||||
|
||||
def setUp(self):
|
||||
"""Generate random bytes string that will be used ass passwords"""
|
||||
self.password1 = utils.gen_saml_id()
|
||||
self.password2 = utils.gen_saml_id()
|
||||
if not isinstance(self.password1, bytes): # pragma: no cover executed only in python3
|
||||
self.password1 = self.password1.encode("utf8")
|
||||
self.password2 = self.password2.encode("utf8")
|
||||
|
||||
def test_setup(self):
|
||||
"""check that generated password are bytes"""
|
||||
self.assertIsInstance(self.password1, bytes)
|
||||
self.assertIsInstance(self.password2, bytes)
|
||||
|
||||
def test_plain(self):
|
||||
"""test the plain auth method"""
|
||||
self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8"))
|
||||
self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8"))
|
||||
|
||||
def test_plain_unicode(self):
|
||||
"""test the plain auth method with unicode input"""
|
||||
self.assertTrue(
|
||||
utils.check_password(
|
||||
"plain",
|
||||
self.password1.decode("utf8"),
|
||||
self.password1.decode("utf8"),
|
||||
"utf8"
|
||||
)
|
||||
)
|
||||
self.assertFalse(
|
||||
utils.check_password(
|
||||
"plain",
|
||||
self.password1.decode("utf8"),
|
||||
self.password2.decode("utf8"),
|
||||
"utf8"
|
||||
)
|
||||
)
|
||||
|
||||
def test_crypt(self):
|
||||
"""test the crypt auth method"""
|
||||
salts = ["$6$UVVAQvrMyXMF3FF3", "aa"]
|
||||
hashed_password1 = []
|
||||
for salt in salts:
|
||||
if six.PY3:
|
||||
hashed_password1.append(
|
||||
utils.crypt.crypt(
|
||||
self.password1.decode("utf8"),
|
||||
salt
|
||||
).encode("utf8")
|
||||
)
|
||||
else:
|
||||
hashed_password1.append(utils.crypt.crypt(self.password1, salt))
|
||||
|
||||
for hp1 in hashed_password1:
|
||||
self.assertTrue(utils.check_password("crypt", self.password1, hp1, "utf8"))
|
||||
self.assertFalse(utils.check_password("crypt", self.password2, hp1, "utf8"))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
utils.check_password("crypt", self.password1, b"$truc$s$dsdsd", "utf8")
|
||||
|
||||
def test_ldap_password_valid(self):
|
||||
"""test the ldap auth method with all the schemes"""
|
||||
salt = b"UVVAQvrMyXMF3FF3"
|
||||
schemes_salt = [b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}"]
|
||||
schemes_nosalt = [b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"]
|
||||
hashed_password1 = []
|
||||
for scheme in schemes_salt:
|
||||
hashed_password1.append(
|
||||
utils.LdapHashUserPassword.hash(scheme, self.password1, salt, charset="utf8")
|
||||
)
|
||||
for scheme in schemes_nosalt:
|
||||
hashed_password1.append(
|
||||
utils.LdapHashUserPassword.hash(scheme, self.password1, charset="utf8")
|
||||
)
|
||||
hashed_password1.append(
|
||||
utils.LdapHashUserPassword.hash(
|
||||
b"{CRYPT}",
|
||||
self.password1,
|
||||
b"$6$UVVAQvrMyXMF3FF3",
|
||||
charset="utf8"
|
||||
)
|
||||
)
|
||||
for hp1 in hashed_password1:
|
||||
self.assertIsInstance(hp1, bytes)
|
||||
self.assertTrue(utils.check_password("ldap", self.password1, hp1, "utf8"))
|
||||
self.assertFalse(utils.check_password("ldap", self.password2, hp1, "utf8"))
|
||||
|
||||
def test_ldap_password_fail(self):
|
||||
"""test the ldap auth method with malformed hash or bad schemes"""
|
||||
salt = b"UVVAQvrMyXMF3FF3"
|
||||
schemes_salt = [b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}"]
|
||||
schemes_nosalt = [b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"]
|
||||
|
||||
# first try to hash with bad parameters
|
||||
with self.assertRaises(utils.LdapHashUserPassword.BadScheme):
|
||||
utils.LdapHashUserPassword.hash(b"TOTO", self.password1)
|
||||
for scheme in schemes_nosalt:
|
||||
with self.assertRaises(utils.LdapHashUserPassword.BadScheme):
|
||||
utils.LdapHashUserPassword.hash(scheme, self.password1, salt)
|
||||
for scheme in schemes_salt:
|
||||
with self.assertRaises(utils.LdapHashUserPassword.BadScheme):
|
||||
utils.LdapHashUserPassword.hash(scheme, self.password1)
|
||||
with self.assertRaises(utils.LdapHashUserPassword.BadSalt):
|
||||
utils.LdapHashUserPassword.hash(b'{CRYPT}', self.password1, b"$truc$toto")
|
||||
|
||||
# then try to check hash with bad hashes
|
||||
with self.assertRaises(utils.LdapHashUserPassword.BadHash):
|
||||
utils.check_password("ldap", self.password1, b"TOTOssdsdsd", "utf8")
|
||||
for scheme in schemes_salt:
|
||||
with self.assertRaises(utils.LdapHashUserPassword.BadHash):
|
||||
utils.check_password("ldap", self.password1, scheme + b"dG90b3E8ZHNkcw==", "utf8")
|
||||
|
||||
def test_hex(self):
|
||||
"""test all the hex_HASH method: the hashed password is a simple hash of the password"""
|
||||
hashes = ["md5", "sha1", "sha224", "sha256", "sha384", "sha512"]
|
||||
hashed_password1 = []
|
||||
for hash in hashes:
|
||||
hashed_password1.append(
|
||||
("hex_%s" % hash, getattr(utils.hashlib, hash)(self.password1).hexdigest())
|
||||
)
|
||||
for (method, hp1) in hashed_password1:
|
||||
self.assertTrue(utils.check_password(method, self.password1, hp1, "utf8"))
|
||||
self.assertFalse(utils.check_password(method, self.password2, hp1, "utf8"))
|
||||
|
||||
def test_bad_method(self):
|
||||
"""try to check password with a bad method, should raise a ValueError"""
|
||||
with self.assertRaises(ValueError):
|
||||
utils.check_password("test", self.password1, b"$truc$s$dsdsd", "utf8")
|
||||
|
||||
|
||||
class UtilsTestCase(TestCase):
|
||||
"""tests for some little utils functions"""
|
||||
def test_import_attr(self):
|
||||
"""
|
||||
test the import_attr function. Feeded with a dotted path string, it should
|
||||
import the dotted module and return that last componend of the dotted path
|
||||
(function, class or variable)
|
||||
"""
|
||||
with self.assertRaises(ImportError):
|
||||
utils.import_attr('toto.titi.tutu')
|
||||
with self.assertRaises(AttributeError):
|
||||
utils.import_attr('cas_server.utils.toto')
|
||||
with self.assertRaises(ValueError):
|
||||
utils.import_attr('toto')
|
||||
self.assertEqual(
|
||||
utils.import_attr('cas_server.default_app_config'),
|
||||
'cas_server.apps.CasAppConfig'
|
||||
)
|
||||
self.assertEqual(utils.import_attr(utils), utils)
|
||||
|
||||
def test_update_url(self):
|
||||
"""
|
||||
test the update_url function. Given an url with possible GET parameter and a dict
|
||||
the function build a url with GET parameters updated by the dictionnary
|
||||
"""
|
||||
url1 = utils.update_url(u"https://www.example.com?toto=1", {u"tata": u"2"})
|
||||
url2 = utils.update_url(b"https://www.example.com?toto=1", {b"tata": b"2"})
|
||||
self.assertEqual(url1, u"https://www.example.com?tata=2&toto=1")
|
||||
self.assertEqual(url2, u"https://www.example.com?tata=2&toto=1")
|
||||
|
||||
url3 = utils.update_url(u"https://www.example.com?toto=1", {u"toto": u"2"})
|
||||
self.assertEqual(url3, u"https://www.example.com?toto=2")
|
||||
|
||||
def test_crypt_salt_is_valid(self):
|
||||
"""test the function crypt_salt_is_valid who test if a crypt salt is valid"""
|
||||
self.assertFalse(utils.crypt_salt_is_valid("")) # len 0
|
||||
self.assertFalse(utils.crypt_salt_is_valid("a")) # len 1
|
||||
self.assertFalse(utils.crypt_salt_is_valid("$$")) # start with $ followed by $
|
||||
self.assertFalse(utils.crypt_salt_is_valid("$toto")) # start with $ but no secondary $
|
||||
self.assertFalse(utils.crypt_salt_is_valid("$toto$toto")) # algorithm toto not known
|
1813
cas_server/tests/test_view.py
Normal file
1813
cas_server/tests/test_view.py
Normal file
File diff suppressed because it is too large
Load Diff
22
cas_server/tests/urls.py
Normal file
22
cas_server/tests/urls.py
Normal file
@ -0,0 +1,22 @@
|
||||
"""cas URL Configuration
|
||||
|
||||
The `urlpatterns` list routes URLs to views. For more information please see:
|
||||
https://docs.djangoproject.com/en/1.9/topics/http/urls/
|
||||
Examples:
|
||||
Function views
|
||||
1. Add an import: from my_app import views
|
||||
2. Add a URL to urlpatterns: url(r'^$', views.home, name='home')
|
||||
Class-based views
|
||||
1. Add an import: from other_app.views import Home
|
||||
2. Add a URL to urlpatterns: url(r'^$', Home.as_view(), name='home')
|
||||
Including another URLconf
|
||||
1. Import the include() function: from django.conf.urls import url, include, include
|
||||
2. Add a URL to urlpatterns: url(r'^blog/', include('blog.urls'))
|
||||
"""
|
||||
from django.conf.urls import url, include
|
||||
from django.contrib import admin
|
||||
|
||||
urlpatterns = [
|
||||
url(r'^admin/', admin.site.urls),
|
||||
url(r'^', include('cas_server.urls', namespace='cas_server')),
|
||||
]
|
180
cas_server/tests/utils.py
Normal file
180
cas_server/tests/utils.py
Normal file
@ -0,0 +1,180 @@
|
||||
# ⁻*- coding: utf-8 -*-
|
||||
# This program is distributed in the hope that it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
|
||||
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
|
||||
# more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License version 3
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
||||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||
#
|
||||
# (c) 2016 Valentin Samir
|
||||
"""Some utils functions for tests"""
|
||||
from cas_server.default_settings import settings
|
||||
|
||||
from django.test import Client
|
||||
|
||||
import cgi
|
||||
from threading import Thread
|
||||
from lxml import etree
|
||||
from six.moves import BaseHTTPServer
|
||||
from six.moves.urllib.parse import urlparse, parse_qsl
|
||||
|
||||
from cas_server import models
|
||||
|
||||
|
||||
def copy_form(form):
|
||||
"""Copy form value into a dict"""
|
||||
params = {}
|
||||
for field in form:
|
||||
if field.value():
|
||||
params[field.name] = field.value()
|
||||
else:
|
||||
params[field.name] = ""
|
||||
return params
|
||||
|
||||
|
||||
def get_login_page_params(client=None):
|
||||
"""Return a client and the POST params for the client to login"""
|
||||
if client is None:
|
||||
client = Client()
|
||||
response = client.get('/login')
|
||||
params = copy_form(response.context["form"])
|
||||
return client, params
|
||||
|
||||
|
||||
def get_auth_client(**update):
|
||||
"""return a authenticated client"""
|
||||
client, params = get_login_page_params()
|
||||
params["username"] = settings.CAS_TEST_USER
|
||||
params["password"] = settings.CAS_TEST_PASSWORD
|
||||
params.update(update)
|
||||
|
||||
client.post('/login', params)
|
||||
assert client.session.get("authenticated")
|
||||
|
||||
return client
|
||||
|
||||
|
||||
def get_user_ticket_request(service):
|
||||
"""Make an auth client to request a ticket for `service`, return the tuple (user, ticket)"""
|
||||
client = get_auth_client()
|
||||
response = client.get("/login", {"service": service})
|
||||
ticket_value = response['Location'].split('ticket=')[-1]
|
||||
user = models.User.objects.get(
|
||||
username=settings.CAS_TEST_USER,
|
||||
session_key=client.session.session_key
|
||||
)
|
||||
ticket = models.ServiceTicket.objects.get(value=ticket_value)
|
||||
return (user, ticket, client)
|
||||
|
||||
|
||||
def get_validated_ticket(service):
|
||||
"""Return a tick that has being already validated. Used to test SLO"""
|
||||
(ticket, auth_client) = get_user_ticket_request(service)[1:3]
|
||||
|
||||
client = Client()
|
||||
response = client.get('/validate', {'ticket': ticket.value, 'service': service})
|
||||
assert (response.status_code == 200)
|
||||
assert (response.content == b'yes\ntest\n')
|
||||
|
||||
ticket = models.ServiceTicket.objects.get(value=ticket.value)
|
||||
return (auth_client, ticket)
|
||||
|
||||
|
||||
def get_pgt():
|
||||
"""return a dict contening a service, user and PGT ticket for this service"""
|
||||
(httpd, host, port) = HttpParamsHandler.run()[0:3]
|
||||
service = "http://%s:%s" % (host, port)
|
||||
|
||||
(user, ticket) = get_user_ticket_request(service)[:2]
|
||||
|
||||
client = Client()
|
||||
client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service})
|
||||
params = httpd.PARAMS
|
||||
|
||||
params["service"] = service
|
||||
params["user"] = user
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def get_proxy_ticket(service):
|
||||
"""Return a ProxyTicket waiting for validation"""
|
||||
params = get_pgt()
|
||||
|
||||
# get a proxy ticket
|
||||
client = Client()
|
||||
response = client.get('/proxy', {'pgt': params['pgtId'], 'targetService': service})
|
||||
root = etree.fromstring(response.content)
|
||||
proxy_ticket = root.xpath(
|
||||
"//cas:proxyTicket",
|
||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
||||
)
|
||||
proxy_ticket = proxy_ticket[0].text
|
||||
ticket = models.ProxyTicket.objects.get(value=proxy_ticket)
|
||||
return ticket
|
||||
|
||||
|
||||
class HttpParamsHandler(BaseHTTPServer.BaseHTTPRequestHandler):
|
||||
"""
|
||||
A simple http server that return 200 on GET or POST
|
||||
and store GET or POST parameters. Used in unit tests
|
||||
"""
|
||||
|
||||
def do_GET(self):
|
||||
"""Called on a GET request on the BaseHTTPServer"""
|
||||
self.send_response(200)
|
||||
self.send_header(b"Content-type", "text/plain")
|
||||
self.end_headers()
|
||||
self.wfile.write(b"ok")
|
||||
url = urlparse(self.path)
|
||||
params = dict(parse_qsl(url.query))
|
||||
self.server.PARAMS = params
|
||||
|
||||
def do_POST(self):
|
||||
"""Called on a POST request on the BaseHTTPServer"""
|
||||
ctype, pdict = cgi.parse_header(self.headers.get('content-type'))
|
||||
if ctype == 'multipart/form-data':
|
||||
postvars = cgi.parse_multipart(self.rfile, pdict)
|
||||
elif ctype == 'application/x-www-form-urlencoded':
|
||||
length = int(self.headers.get('content-length'))
|
||||
postvars = cgi.parse_qs(self.rfile.read(length), keep_blank_values=1)
|
||||
else:
|
||||
postvars = {}
|
||||
self.server.PARAMS = postvars
|
||||
|
||||
def log_message(self, *args):
|
||||
"""silent any log message"""
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def run(cls):
|
||||
"""Run a BaseHTTPServer using this class as handler"""
|
||||
server_class = BaseHTTPServer.HTTPServer
|
||||
httpd = server_class(("127.0.0.1", 0), cls)
|
||||
(host, port) = httpd.socket.getsockname()
|
||||
|
||||
def lauch():
|
||||
"""routine to lauch in a background thread"""
|
||||
httpd.handle_request()
|
||||
httpd.server_close()
|
||||
|
||||
httpd_thread = Thread(target=lauch)
|
||||
httpd_thread.daemon = True
|
||||
httpd_thread.start()
|
||||
return (httpd, host, port)
|
||||
|
||||
|
||||
class Http404Handler(HttpParamsHandler):
|
||||
"""A simple http server that always return 404 not found. Used in unit tests"""
|
||||
def do_GET(self):
|
||||
"""Called on a GET request on the BaseHTTPServer"""
|
||||
self.send_response(404)
|
||||
self.send_header(b"Content-type", "text/plain")
|
||||
self.end_headers()
|
||||
self.wfile.write(b"error 404 not found")
|
||||
|
||||
def do_POST(self):
|
||||
"""Called on a POST request on the BaseHTTPServer"""
|
||||
return self.do_GET()
|
Reference in New Issue
Block a user