Merge branch 'dev' into federate

This commit is contained in:
Valentin Samir
2016-06-28 00:34:31 +02:00
28 changed files with 1472 additions and 779 deletions

View File

@ -9,4 +9,4 @@
#
# (c) 2015 Valentin Samir
default_app_config = 'cas_server.apps.AppConfig'
default_app_config = 'cas_server.apps.CasAppConfig'

View File

@ -14,9 +14,9 @@ from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket, User, Servi
from .models import Username, ReplaceAttributName, ReplaceAttributValue, FilterAttributValue
from .forms import TicketForm
tickets_readonly_fields = ('validate', 'service', 'service_pattern',
TICKETS_READONLY_FIELDS = ('validate', 'service', 'service_pattern',
'creation', 'renew', 'single_log_out', 'value')
tickets_fields = ('validate', 'service', 'service_pattern',
TICKETS_FIELDS = ('validate', 'service', 'service_pattern',
'creation', 'renew', 'single_log_out')
@ -25,8 +25,8 @@ class ServiceTicketInline(admin.TabularInline):
model = ServiceTicket
extra = 0
form = TicketForm
readonly_fields = tickets_readonly_fields
fields = tickets_fields
readonly_fields = TICKETS_READONLY_FIELDS
fields = TICKETS_FIELDS
class ProxyTicketInline(admin.TabularInline):
@ -34,8 +34,8 @@ class ProxyTicketInline(admin.TabularInline):
model = ProxyTicket
extra = 0
form = TicketForm
readonly_fields = tickets_readonly_fields
fields = tickets_fields
readonly_fields = TICKETS_READONLY_FIELDS
fields = TICKETS_FIELDS
class ProxyGrantingInline(admin.TabularInline):
@ -43,8 +43,8 @@ class ProxyGrantingInline(admin.TabularInline):
model = ProxyGrantingTicket
extra = 0
form = TicketForm
readonly_fields = tickets_readonly_fields
fields = tickets_fields[1:]
readonly_fields = TICKETS_READONLY_FIELDS
fields = TICKETS_FIELDS[1:]
class UserAdmin(admin.ModelAdmin):

View File

@ -2,6 +2,6 @@ from django.utils.translation import ugettext_lazy as _
from django.apps import AppConfig
class AppConfig(AppConfig):
class CasAppConfig(AppConfig):
name = 'cas_server'
verbose_name = _('Central Authentication Service')

View File

@ -15,10 +15,10 @@ from django.contrib.auth import get_user_model
from django.utils import timezone
from datetime import timedelta
try:
try: # pragma: no cover
import MySQLdb
import MySQLdb.cursors
import crypt
from utils import check_password
except ImportError:
MySQLdb = None
@ -31,14 +31,14 @@ class AuthUser(object):
def test_password(self, password):
"""test `password` agains the user"""
raise NotImplemented()
raise NotImplementedError()
def attributs(self):
"""return a dict of user attributes"""
raise NotImplemented()
raise NotImplementedError()
class DummyAuthUser(AuthUser):
class DummyAuthUser(AuthUser): # pragma: no cover
"""A Dummy authentication class"""
def __init__(self, username):
@ -62,14 +62,14 @@ class TestAuthUser(AuthUser):
def test_password(self, password):
"""test `password` agains the user"""
return self.username == "test" and password == "test"
return self.username == settings.CAS_TEST_USER and password == settings.CAS_TEST_PASSWORD
def attributs(self):
"""return a dict of user attributes"""
return {'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}
return settings.CAS_TEST_ATTRIBUTES
class MysqlAuthUser(AuthUser):
class MysqlAuthUser(AuthUser): # pragma: no cover
"""A mysql auth class: authentication user agains a mysql database"""
user = None
@ -94,30 +94,25 @@ class MysqlAuthUser(AuthUser):
def test_password(self, password):
"""test `password` agains the user"""
if not self.user:
return False
if self.user:
return check_password(
settings.CAS_SQL_PASSWORD_CHECK,
password,
self.user["password"],
settings.CAS_SQL_DBCHARSET
)
else:
if settings.CAS_SQL_PASSWORD_CHECK == "plain":
return password == self.user["password"]
elif settings.CAS_SQL_PASSWORD_CHECK == "crypt":
if self.user["password"].startswith('$'):
salt = '$'.join(self.user["password"].split('$', 3)[:-1])
return crypt.crypt(password, salt) == self.user["password"]
else:
return crypt.crypt(
password,
self.user["password"][:2]
) == self.user["password"]
return False
def attributs(self):
"""return a dict of user attributes"""
if not self.user:
return {}
else:
if self.user:
return self.user
else:
return {}
class DjangoAuthUser(AuthUser):
class DjangoAuthUser(AuthUser): # pragma: no cover
"""A django auth class: authenticate user agains django internal users"""
user = None
@ -131,21 +126,20 @@ class DjangoAuthUser(AuthUser):
def test_password(self, password):
"""test `password` agains the user"""
if not self.user:
return False
else:
if self.user:
return self.user.check_password(password)
else:
return False
def attributs(self):
"""return a dict of user attributes"""
if not self.user:
return {}
else:
if self.user:
attr = {}
for field in self.user._meta.fields:
attr[field.attname] = getattr(self.user, field.attname)
return attr
else:
return {}
class CASFederateAuth(AuthUser):
user = None

View File

@ -78,6 +78,17 @@ setting_default('CAS_SQL_USER_QUERY', 'SELECT user AS usersame, pass AS '
'password, users.* FROM users WHERE user = %s')
setting_default('CAS_SQL_PASSWORD_CHECK', 'crypt') # crypt or plain
setting_default('CAS_TEST_USER', 'test')
setting_default('CAS_TEST_PASSWORD', 'test')
setting_default(
'CAS_TEST_ATTRIBUTES',
{
'nom': 'Nymous',
'prenom': 'Ano',
'email': 'anonymous@example.net',
'alias': ['demo1', 'demo2']
}
)
setting_default('CAS_FEDERATE', False)
# A dict of "provider suffix" -> (provider CAS server url, CAS version, verbose name)

View File

@ -1,55 +1,52 @@
function cas_login(cas_server_login, service, login_service, callback){
url = cas_server_login + '?service=' + encodeURIComponent(service);
var url = cas_server_login + "?service=" + encodeURIComponent(service);
$.ajax({
type: 'GET',
url:url,
beforeSend: function (request) {
type: "GET",
url,
beforeSend(request) {
request.setRequestHeader("X-AJAX", "1");
},
xhrFields: {
withCredentials: true
},
success: function(data, textStatus, request){
if(data.status == 'success'){
success(data, textStatus, request){
if(data.status === "success"){
$.ajax({
type: 'GET',
type: "GET",
url: data.url,
xhrFields: {
withCredentials: true
},
success: callback,
error: function (request, textStatus, errorThrown) {},
error(request, textStatus, errorThrown) {},
});
} else {
if(data.detail == "login required"){
window.location.href = cas_server_login + '?service=' + encodeURIComponent(login_service);
if(data.detail === "login required"){
window.location.href = cas_server_login + "?service=" + encodeURIComponent(login_service);
} else {
alert('error: ' + data.messages[1].message);
alert("error: " + data.messages[1].message);
}
}
},
error: function (request, textStatus, errorThrown) {},
error(request, textStatus, errorThrown) {},
});
}
function cas_logout(cas_server_logout){
$.ajax({
type: 'GET',
url:cas_server_logout,
beforeSend: function (request) {
type: "GET",
url: cas_server_logout,
beforeSend(request) {
request.setRequestHeader("X-AJAX", "1");
},
xhrFields: {
withCredentials: true
},
error: function (request, textStatus, errorThrown) {},
success: function(data, textStatus, request){
if(data.status == 'error'){
alert('error: ' + data.messages[1].message);
error(request, textStatus, errorThrown) {},
success(data, textStatus, request){
if(data.status === "error"){
alert("error: " + data.messages[1].message);
}
},
});
}

View File

@ -43,14 +43,14 @@ body {
@media screen and (max-width: 680px) {
#app-name {
margin: 0px;
margin: 0;
}
#app-name img {
display: block;
margin: auto;
}
body {
padding-top: 0px;
padding-bottom: 0px;
padding-top: 0;
padding-bottom: 0;
}
}

943
cas_server/tests.py Normal file
View File

@ -0,0 +1,943 @@
from .default_settings import settings
import django
from django.test import TestCase
from django.test import Client
import re
import six
import random
from lxml import etree
from six.moves import range
from cas_server import models
from cas_server import utils
def copy_form(form):
"""Copy form value into a dict"""
params = {}
for field in form:
if field.value():
params[field.name] = field.value()
else:
params[field.name] = ""
return params
def get_login_page_params(client=None):
"""Return a client and the POST params for the client to login"""
if client is None:
client = Client()
response = client.get('/login')
params = copy_form(response.context["form"])
return client, params
def get_auth_client(**update):
"""return a authenticated client"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
params.update(update)
client.post('/login', params)
return client
def get_user_ticket_request(service):
"""Make an auth client to request a ticket for `service`, return the tuple (user, ticket)"""
client = get_auth_client()
response = client.get("/login", {"service": service})
ticket_value = response['Location'].split('ticket=')[-1]
user = models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)
ticket = models.ServiceTicket.objects.get(value=ticket_value)
return (user, ticket)
def get_pgt():
"""return a dict contening a service, user and PGT ticket for this service"""
(host, port) = utils.PGTUrlHandler.run()[1:3]
service = "http://%s:%s" % (host, port)
(user, ticket) = get_user_ticket_request(service)
client = Client()
client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service})
params = utils.PGTUrlHandler.PARAMS.copy()
params["service"] = service
params["user"] = user
return params
class CheckPasswordCase(TestCase):
"""Tests for the utils function `utils.check_password`"""
def setUp(self):
"""Generate random bytes string that will be used ass passwords"""
self.password1 = utils.gen_saml_id()
self.password2 = utils.gen_saml_id()
if not isinstance(self.password1, bytes):
self.password1 = self.password1.encode("utf8")
self.password2 = self.password2.encode("utf8")
def test_setup(self):
"""check that generated password are bytes"""
self.assertIsInstance(self.password1, bytes)
self.assertIsInstance(self.password2, bytes)
def test_plain(self):
"""test the plain auth method"""
self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8"))
self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8"))
def test_crypt(self):
"""test the crypt auth method"""
if six.PY3:
hashed_password1 = utils.crypt.crypt(
self.password1.decode("utf8"),
"$6$UVVAQvrMyXMF3FF3"
).encode("utf8")
else:
hashed_password1 = utils.crypt.crypt(self.password1, "$6$UVVAQvrMyXMF3FF3")
self.assertTrue(utils.check_password("crypt", self.password1, hashed_password1, "utf8"))
self.assertFalse(utils.check_password("crypt", self.password2, hashed_password1, "utf8"))
def test_ldap_ssha(self):
"""test the ldap auth method with a {SSHA} scheme"""
salt = b"UVVAQvrMyXMF3FF3"
hashed_password1 = utils.LdapHashUserPassword.hash(b'{SSHA}', self.password1, salt, "utf8")
self.assertIsInstance(hashed_password1, bytes)
self.assertTrue(utils.check_password("ldap", self.password1, hashed_password1, "utf8"))
self.assertFalse(utils.check_password("ldap", self.password2, hashed_password1, "utf8"))
def test_hex_md5(self):
"""test the hex_md5 auth method"""
hashed_password1 = utils.hashlib.md5(self.password1).hexdigest()
self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8"))
self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8"))
def test_hex_sha512(self):
"""test the hex_sha512 auth method"""
hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest()
self.assertTrue(
utils.check_password("hex_sha512", self.password1, hashed_password1, "utf8")
)
self.assertFalse(
utils.check_password("hex_sha512", self.password2, hashed_password1, "utf8")
)
class LoginTestCase(TestCase):
"""Tests for the login view"""
def setUp(self):
"""
Prepare the test context:
* set the auth class to 'cas_server.auth.TestAuthUser'
* create a service pattern for https://www.example.com/**
* Set the service pattern to return all user attributes
"""
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
# For general purpose testing
self.service_pattern = models.ServicePattern.objects.create(
name="example",
pattern="^https://www\.example\.com(/.*)?$",
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
# For testing the restrict_users attributes
self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create(
name="restrict_user_fail",
pattern="^https://restrict_user_fail\.example\.com(/.*)?$",
restrict_users=True,
)
self.service_pattern_restrict_user_success = models.ServicePattern.objects.create(
name="restrict_user_success",
pattern="^https://restrict_user_success\.example\.com(/.*)?$",
restrict_users=True,
)
models.Username.objects.create(
value=settings.CAS_TEST_USER,
service_pattern=self.service_pattern_restrict_user_success
)
# For testing the user attributes filtering conditions
self.service_pattern_filter_fail = models.ServicePattern.objects.create(
name="filter_fail",
pattern="^https://filter_fail\.example\.com(/.*)?$",
)
models.FilterAttributValue.objects.create(
attribut="right",
pattern="^admin$",
service_pattern=self.service_pattern_filter_fail
)
self.service_pattern_filter_success = models.ServicePattern.objects.create(
name="filter_success",
pattern="^https://filter_success\.example\.com(/.*)?$",
)
models.FilterAttributValue.objects.create(
attribut="email",
pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']),
service_pattern=self.service_pattern_filter_success
)
# For testing the user_field attributes
self.service_pattern_field_needed_fail = models.ServicePattern.objects.create(
name="field_needed_fail",
pattern="^https://field_needed_fail\.example\.com(/.*)?$",
user_field="uid"
)
self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
name="field_needed_success",
pattern="^https://field_needed_success\.example\.com(/.*)?$",
user_field="nom"
)
def assert_logged(self, client, response, warn=False, code=200):
"""Assertions testing that client is well authenticated"""
self.assertEqual(response.status_code, code)
self.assertTrue(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
self.assertTrue(client.session["username"] == settings.CAS_TEST_USER)
self.assertTrue(client.session["warn"] is warn)
self.assertTrue(client.session["authenticated"] is True)
self.assertTrue(
models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)
)
def assert_login_failed(self, client, response, code=200):
"""Assertions testing a failed login attempt"""
self.assertEqual(response.status_code, code)
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
self.assertTrue(client.session.get("username") is None)
self.assertTrue(client.session.get("warn") is None)
self.assertTrue(client.session.get("authenticated") is None)
def test_login_view_post_goodpass_goodlt(self):
"""Test a successul login"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
self.assertTrue(params['lt'] in client.session['lt'])
response = client.post('/login', params)
self.assert_logged(client, response)
# LoginTicket conssumed
self.assertTrue(params['lt'] not in client.session['lt'])
def test_login_view_post_goodpass_goodlt_warn(self):
"""Test a successul login requesting to be warned before creating services tickets"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
params["warn"] = "on"
response = client.post('/login', params)
self.assert_logged(client, response, warn=True)
def test_lt_max(self):
"""Check we only keep the last 100 Login Ticket for a user"""
client, params = get_login_page_params()
current_lt = params["lt"]
i_in_test = random.randint(0, 100)
i_not_in_test = random.randint(100, 150)
for i in range(150):
if i == i_in_test:
self.assertTrue(current_lt in client.session['lt'])
if i == i_not_in_test:
self.assertTrue(current_lt not in client.session['lt'])
self.assertTrue(len(client.session['lt']) <= 100)
client, params = get_login_page_params(client)
self.assertTrue(len(client.session['lt']) <= 100)
def test_login_view_post_badlt(self):
"""Login attempt with a bad LoginTicket"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
params["lt"] = 'LT-random'
response = client.post('/login', params)
self.assert_login_failed(client, response)
self.assertTrue(b"Invalid login ticket" in response.content)
def test_login_view_post_badpass_good_lt(self):
"""Login attempt with a bad password"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = "test2"
response = client.post('/login', params)
self.assert_login_failed(client, response)
self.assertTrue(
(
b"The credentials you provided cannot be "
b"determined to be authentic"
) in response.content
)
def assert_ticket_attributes(self, client, ticket_value):
"""check the ticket attributes in the db"""
user = models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)
self.assertTrue(user)
ticket = models.ServiceTicket.objects.get(value=ticket_value)
self.assertEqual(ticket.user, user)
self.assertEqual(ticket.attributs, settings.CAS_TEST_ATTRIBUTES)
self.assertEqual(ticket.validate, False)
self.assertEqual(ticket.service_pattern, self.service_pattern)
def assert_service_ticket(self, client, response):
"""check that a ticket is well emited when requested on a allowed service"""
self.assertEqual(response.status_code, 302)
self.assertTrue(response.has_header('Location'))
self.assertTrue(
response['Location'].startswith(
"https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX
)
)
ticket_value = response['Location'].split('ticket=')[-1]
self.assert_ticket_attributes(client, ticket_value)
def test_view_login_get_allowed_service(self):
"""Request a ticket for an allowed service by an unauthenticated client"""
client = Client()
response = client.get("/login?service=https://www.example.com")
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
b"Authentication required by service "
b"example (https://www.example.com)"
) in response.content
)
def test_view_login_get_denied_service(self):
"""Request a ticket for an denied service by an unauthenticated client"""
client = Client()
response = client.get("/login?service=https://www.example.net")
self.assertEqual(response.status_code, 200)
self.assertTrue(b"Service https://www.example.net non allowed" in response.content)
def test_view_login_get_auth_allowed_service(self):
"""Request a ticket for an allowed service by an authenticated client"""
# client is already authenticated
client = get_auth_client()
response = client.get("/login?service=https://www.example.com")
self.assert_service_ticket(client, response)
def test_view_login_get_auth_allowed_service_warn(self):
"""Request a ticket for an allowed service by an authenticated client"""
# client is already authenticated
client = get_auth_client(warn="on")
response = client.get("/login?service=https://www.example.com")
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
b"Authentication has been required by service "
b"example (https://www.example.com)"
) in response.content
)
params = copy_form(response.context["form"])
response = client.post("/login", params)
self.assert_service_ticket(client, response)
def test_view_login_get_auth_denied_service(self):
"""Request a ticket for a not allowed service by an authenticated client"""
client = get_auth_client()
response = client.get("/login?service=https://www.example.org")
self.assertEqual(response.status_code, 200)
self.assertTrue(b"Service https://www.example.org non allowed" in response.content)
def test_user_logged_not_in_db(self):
"""If the user is logged but has been delete from the database, it should be logged out"""
client = get_auth_client()
models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
).delete()
response = client.get("/login")
self.assert_login_failed(client, response, code=302)
if django.VERSION < (1, 9):
self.assertEqual(response["Location"], "http://testserver/login")
else:
self.assertEqual(response["Location"], "/login?")
def test_service_restrict_user(self):
"""Testing the restric user capability fro a service"""
service = "https://restrict_user_fail.example.com"
client = get_auth_client()
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 200)
self.assertTrue(b"Username non allowed" in response.content)
service = "https://restrict_user_success.example.com"
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 302)
self.assertTrue(response["Location"].startswith("%s?ticket=" % service))
def test_service_filter(self):
"""Test the filtering on user attributes"""
service = "https://filter_fail.example.com"
client = get_auth_client()
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 200)
self.assertTrue(b"User charateristics non allowed" in response.content)
service = "https://filter_success.example.com"
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 302)
self.assertTrue(response["Location"].startswith("%s?ticket=" % service))
def test_service_user_field(self):
"""Test using a user attribute as username: case on if the attribute exists or not"""
service = "https://field_needed_fail.example.com"
client = get_auth_client()
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 200)
self.assertTrue(b"The attribut uid is needed to use that service" in response.content)
service = "https://field_needed_success.example.com"
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 302)
self.assertTrue(response["Location"].startswith("%s?ticket=" % service))
def test_gateway(self):
"""test gateway parameter"""
# First with an authenticated client that fail to get a ticket for a service
service = "https://restrict_user_fail.example.com"
client = get_auth_client()
response = client.get("/login", {'service': service, 'gateway': 'on'})
self.assertEqual(response.status_code, 302)
self.assertEqual(response["Location"], service)
# second for an user not yet authenticated on a valid service
client = Client()
response = client.get('/login', {'service': service, 'gateway': 'on'})
self.assertEqual(response.status_code, 302)
self.assertEqual(response["Location"], service)
class LogoutTestCase(TestCase):
def setUp(self):
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
def test_logout_view(self):
client = get_auth_client()
response = client.get("/login")
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
response = client.get("/logout")
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
b"You have successfully logged out from "
b"the Central Authentication Service"
) in response.content
)
response = client.get("/login")
self.assertEqual(response.status_code, 200)
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
def test_logout_view_url(self):
client = get_auth_client()
response = client.get('/logout?url=https://www.example.com')
self.assertEqual(response.status_code, 302)
self.assertTrue(response.has_header("Location"))
self.assertEqual(response["Location"], "https://www.example.com")
response = client.get("/login")
self.assertEqual(response.status_code, 200)
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
def test_logout_view_service(self):
client = get_auth_client()
response = client.get('/logout?service=https://www.example.com')
self.assertEqual(response.status_code, 302)
self.assertTrue(response.has_header("Location"))
self.assertEqual(response["Location"], "https://www.example.com")
response = client.get("/login")
self.assertEqual(response.status_code, 200)
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
class AuthTestCase(TestCase):
def setUp(self):
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
self.service = 'https://www.example.com'
models.ServicePattern.objects.create(
name="example",
pattern="^https://www\.example\.com(/.*)?$"
)
def test_auth_view_goodpass(self):
settings.CAS_AUTH_SHARED_SECRET = 'test'
client = Client()
response = client.post(
'/auth',
{
'username': settings.CAS_TEST_USER,
'password': settings.CAS_TEST_PASSWORD,
'service': self.service,
'secret': 'test'
}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'yes\n')
def test_auth_view_badpass(self):
settings.CAS_AUTH_SHARED_SECRET = 'test'
client = Client()
response = client.post(
'/auth',
{
'username': settings.CAS_TEST_USER,
'password': 'badpass',
'service': self.service,
'secret': 'test'
}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
def test_auth_view_badservice(self):
settings.CAS_AUTH_SHARED_SECRET = 'test'
client = Client()
response = client.post(
'/auth',
{
'username': settings.CAS_TEST_USER,
'password': settings.CAS_TEST_PASSWORD,
'service': 'https://www.example.org',
'secret': 'test'
}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
def test_auth_view_badsecret(self):
settings.CAS_AUTH_SHARED_SECRET = 'test'
client = Client()
response = client.post(
'/auth',
{
'username': settings.CAS_TEST_USER,
'password': settings.CAS_TEST_PASSWORD,
'service': self.service,
'secret': 'badsecret'
}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
def test_auth_view_badsettings(self):
settings.CAS_AUTH_SHARED_SECRET = None
client = Client()
response = client.post(
'/auth',
{
'username': settings.CAS_TEST_USER,
'password': settings.CAS_TEST_PASSWORD,
'service': self.service,
'secret': 'test'
}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b"no\nplease set CAS_AUTH_SHARED_SECRET")
class ValidateTestCase(TestCase):
def setUp(self):
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
self.service = 'https://www.example.com'
self.service_pattern = models.ServicePattern.objects.create(
name="example",
pattern="^https://www\.example\.com(/.*)?$"
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_validate_view_ok(self):
ticket = get_user_ticket_request(self.service)[1]
client = Client()
response = client.get('/validate', {'ticket': ticket.value, 'service': self.service})
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'yes\ntest\n')
def test_validate_view_badservice(self):
ticket = get_user_ticket_request(self.service)[1]
client = Client()
response = client.get(
'/validate',
{'ticket': ticket.value, 'service': "https://www.example.org"}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
def test_validate_view_badticket(self):
get_user_ticket_request(self.service)
client = Client()
response = client.get(
'/validate',
{'ticket': "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX, 'service': self.service}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
class ValidateServiceTestCase(TestCase):
def setUp(self):
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
self.service = 'http://127.0.0.1:45678'
self.service_pattern = models.ServicePattern.objects.create(
name="localhost",
pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
proxy_callback=True
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_validate_service_view_ok(self):
ticket = get_user_ticket_request(self.service)[1]
client = Client()
response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': self.service})
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
sucess = root.xpath(
"//cas:authenticationSuccess",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertTrue(sucess)
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(users), 1)
self.assertEqual(users[0].text, settings.CAS_TEST_USER)
attributes = root.xpath(
"//cas:attributes",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(attributes), 1)
attrs1 = set()
for attr in attributes[0]:
attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(attributes), len(attrs1))
attrs2 = set()
for attr in attributes:
attrs2.add((attr.attrib['name'], attr.attrib['value']))
original = set()
for key, value in settings.CAS_TEST_ATTRIBUTES.items():
if isinstance(value, list):
for sub_value in value:
original.add((key, sub_value))
else:
original.add((key, value))
self.assertEqual(attrs1, attrs2)
self.assertEqual(attrs1, original)
def test_validate_service_view_badservice(self):
ticket = get_user_ticket_request(self.service)[1]
client = Client()
bad_service = "https://www.example.org"
response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': bad_service})
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "INVALID_SERVICE")
self.assertEqual(error[0].text, bad_service)
def test_validate_service_view_badticket_goodprefix(self):
get_user_ticket_request(self.service)
client = Client()
bad_ticket = "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX
response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service})
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
self.assertEqual(error[0].text, 'ticket not found')
def test_validate_service_view_badticket_badprefix(self):
get_user_ticket_request(self.service)
client = Client()
bad_ticket = "RANDOM"
response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service})
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
self.assertEqual(error[0].text, bad_ticket)
def test_validate_service_view_ok_pgturl(self):
(host, port) = utils.PGTUrlHandler.run()[1:3]
service = "http://%s:%s" % (host, port)
ticket = get_user_ticket_request(service)[1]
client = Client()
response = client.get(
'/serviceValidate',
{'ticket': ticket.value, 'service': service, 'pgtUrl': service}
)
pgt_params = utils.PGTUrlHandler.PARAMS.copy()
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
pgtiou = root.xpath(
"//cas:proxyGrantingTicket",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(pgtiou), 1)
self.assertEqual(pgt_params["pgtIou"], pgtiou[0].text)
self.assertTrue("pgtId" in pgt_params)
def test_validate_service_pgturl_bad_proxy_callback(self):
self.service_pattern.proxy_callback = False
self.service_pattern.save()
ticket = get_user_ticket_request(self.service)[1]
client = Client()
response = client.get(
'/serviceValidate',
{'ticket': ticket.value, 'service': self.service, 'pgtUrl': self.service}
)
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "INVALID_PROXY_CALLBACK")
self.assertEqual(error[0].text, "callback url not allowed by configuration")
class ProxyTestCase(TestCase):
def setUp(self):
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
self.service = 'http://127.0.0.1'
self.service_pattern = models.ServicePattern.objects.create(
name="localhost",
pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
proxy=True,
proxy_callback=True
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_validate_proxy_ok(self):
params = get_pgt()
# get a proxy ticket
client1 = Client()
response = client1.get('/proxy', {'pgt': params['pgtId'], 'targetService': self.service})
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
sucess = root.xpath("//cas:proxySuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertTrue(sucess)
proxy_ticket = root.xpath(
"//cas:proxyTicket",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(proxy_ticket), 1)
proxy_ticket = proxy_ticket[0].text
# validate the proxy ticket
client2 = Client()
response = client2.get('/proxyValidate', {'ticket': proxy_ticket, 'service': self.service})
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
sucess = root.xpath(
"//cas:authenticationSuccess",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertTrue(sucess)
# check that the proxy is send to the end service
proxies = root.xpath("//cas:proxies", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(proxies), 1)
proxy = proxies[0].xpath("//cas:proxy", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(proxy), 1)
self.assertEqual(proxy[0].text, params["service"])
# same tests than those for serviceValidate
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(users), 1)
self.assertEqual(users[0].text, settings.CAS_TEST_USER)
attributes = root.xpath(
"//cas:attributes",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(attributes), 1)
attrs1 = set()
for attr in attributes[0]:
attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(attributes), len(attrs1))
attrs2 = set()
for attr in attributes:
attrs2.add((attr.attrib['name'], attr.attrib['value']))
original = set()
for key, value in settings.CAS_TEST_ATTRIBUTES.items():
if isinstance(value, list):
for sub_value in value:
original.add((key, sub_value))
else:
original.add((key, value))
self.assertEqual(attrs1, attrs2)
self.assertEqual(attrs1, original)
def test_validate_proxy_bad(self):
params = get_pgt()
# bad PGT
client1 = Client()
response = client1.get(
'/proxy',
{
'pgt': "%s-RANDOM" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX,
'targetService': params['service']
}
)
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
self.assertEqual(
error[0].text,
"PGT %s-RANDOM not found" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX
)
# bad targetService
client2 = Client()
response = client2.get(
'/proxy',
{'pgt': params['pgtId'], 'targetService': "https://www.example.org"}
)
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE")
self.assertEqual(error[0].text, "https://www.example.org")
# service do not allow proxy ticket
self.service_pattern.proxy = False
self.service_pattern.save()
client3 = Client()
response = client3.get(
'/proxy',
{'pgt': params['pgtId'], 'targetService': params['service']}
)
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE")
self.assertEqual(
error[0].text,
'the service %s do not allow proxy ticket' % params['service']
)

View File

@ -14,7 +14,7 @@ from django.conf.urls import patterns, url
from django.views.generic import RedirectView
from django.views.decorators.debug import sensitive_post_parameters, sensitive_variables
import views
from cas_server import views
urlpatterns = patterns(
'',

View File

@ -19,14 +19,15 @@ from django.contrib import messages
import random
import string
import json
import hashlib
import crypt
import base64
import six
from threading import Thread
from importlib import import_module
from datetime import datetime, timedelta
try:
from urlparse import urlparse, urlunparse, parse_qsl
from urllib import urlencode
except ImportError:
from urllib.parse import urlparse, urlunparse, parse_qsl, urlencode
from six.moves import BaseHTTPServer
from six.moves.urllib.parse import urlparse, urlunparse, parse_qsl, urlencode
def context(params):
@ -34,7 +35,7 @@ def context(params):
return params
def JsonResponse(request, data):
def json_response(request, data):
data["messages"] = []
for msg in messages.get_messages(request):
data["messages"].append({'message': msg.message, 'level': msg.level_tag})
@ -120,9 +121,9 @@ def update_url(url, params):
query = dict(parse_qsl(url_parts[4]))
query.update(params)
url_parts[4] = urlencode(query)
for i in range(len(url_parts)):
if not isinstance(url_parts[i], bytes):
url_parts[i] = url_parts[i].encode('utf-8')
for i, url_part in enumerate(url_parts):
if not isinstance(url_part, bytes):
url_parts[i] = url_part.encode('utf-8')
return urlunparse(url_parts).decode('utf-8')
@ -190,3 +191,207 @@ def get_tuple(tuple, index, default=None):
return tuple[index]
except IndexError:
return default
class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler):
PARAMS = {}
def do_GET(self):
self.send_response(200)
self.send_header(b"Content-type", "text/plain")
self.end_headers()
self.wfile.write(b"ok")
url = urlparse(self.path)
params = dict(parse_qsl(url.query))
PGTUrlHandler.PARAMS.update(params)
def log_message(self, *args):
return
@staticmethod
def run():
server_class = BaseHTTPServer.HTTPServer
httpd = server_class(("127.0.0.1", 0), PGTUrlHandler)
(host, port) = httpd.socket.getsockname()
def lauch():
httpd.handle_request()
httpd.server_close()
httpd_thread = Thread(target=lauch)
httpd_thread.daemon = True
httpd_thread.start()
return (httpd_thread, host, port)
class LdapHashUserPassword(object):
"""Please see https://tools.ietf.org/id/draft-stroeder-hashed-userpassword-values-01.html"""
schemes_salt = {b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}", b"{CRYPT}"}
schemes_nosalt = {b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"}
_schemes_to_hash = {
b"{SMD5}": hashlib.md5,
b"{MD5}": hashlib.md5,
b"{SSHA}": hashlib.sha1,
b"{SHA}": hashlib.sha1,
b"{SSHA256}": hashlib.sha256,
b"{SHA256}": hashlib.sha256,
b"{SSHA384}": hashlib.sha384,
b"{SHA384}": hashlib.sha384,
b"{SSHA512}": hashlib.sha512,
b"{SHA512}": hashlib.sha512
}
_schemes_to_len = {
b"{SMD5}": 16,
b"{SSHA}": 20,
b"{SSHA256}": 32,
b"{SSHA384}": 48,
b"{SSHA512}": 64,
}
class BadScheme(ValueError):
"""Error raised then the hash scheme is not in schemes_salt + schemes_nosalt"""
pass
class BadHash(ValueError):
"""Error raised then the hash is too short"""
pass
class BadSalt(ValueError):
"""Error raised then with the scheme {CRYPT} the salt is invalid"""
pass
@classmethod
def _raise_bad_scheme(cls, scheme, valid, msg):
"""
Raise BadScheme error for `scheme`, possible valid scheme are
in `valid`, the error message is `msg`
"""
valid_schemes = [s.decode() for s in valid]
valid_schemes.sort()
raise cls.BadScheme(msg % (scheme, u", ".join(valid_schemes)))
@classmethod
def _test_scheme(cls, scheme):
"""Test if a scheme is valide or raise BadScheme"""
if scheme not in cls.schemes_salt and scheme not in cls.schemes_nosalt:
cls._raise_bad_scheme(
scheme,
cls.schemes_salt | cls.schemes_nosalt,
"The scheme %r is not valid. Valide schemes are %s."
)
@classmethod
def _test_scheme_salt(cls, scheme):
"""Test if the scheme need a salt or raise BadScheme"""
if scheme not in cls.schemes_salt:
cls._raise_bad_scheme(
scheme,
cls.schemes_salt,
"The scheme %r is only valid without a salt. Valide schemes with salt are %s."
)
@classmethod
def _test_scheme_nosalt(cls, scheme):
"""Test if the scheme need no salt or raise BadScheme"""
if scheme not in cls.schemes_nosalt:
cls._raise_bad_scheme(
scheme,
cls.schemes_nosalt,
"The scheme %r is only valid with a salt. Valide schemes without salt are %s."
)
@classmethod
def hash(cls, scheme, password, salt=None, charset="utf8"):
"""
Hash `password` with `scheme` using `salt`.
This three variable beeing encoded in `charset`.
"""
scheme = scheme.upper()
cls._test_scheme(scheme)
if salt is None or salt == b"":
salt = b""
cls._test_scheme_nosalt(scheme)
elif salt is not None:
cls._test_scheme_salt(scheme)
try:
return scheme + base64.b64encode(
cls._schemes_to_hash[scheme](password + salt).digest() + salt
)
except KeyError:
if six.PY3:
password = password.decode(charset)
salt = salt.decode(charset)
hashed_password = crypt.crypt(password, salt)
if hashed_password is None:
raise cls.BadSalt("System crypt implementation do not support the salt %r" % salt)
if six.PY3:
hashed_password = hashed_password.encode(charset)
return scheme + hashed_password
@classmethod
def get_scheme(cls, hashed_passord):
"""Return the scheme of `hashed_passord` or raise BadHash"""
if not hashed_passord[0] == b'{'[0] or b'}' not in hashed_passord:
raise cls.BadHash("%r should start with the scheme enclosed with { }" % hashed_passord)
scheme = hashed_passord.split(b'}', 1)[0]
scheme = scheme.upper() + b"}"
return scheme
@classmethod
def get_salt(cls, hashed_passord):
"""Return the salt of `hashed_passord` possibly empty"""
scheme = cls.get_scheme(hashed_passord)
cls._test_scheme(scheme)
if scheme in cls.schemes_nosalt:
return b""
elif scheme == b'{CRYPT}':
return b'$'.join(hashed_passord.split(b'$', 3)[:-1])
else:
hashed_passord = base64.b64decode(hashed_passord[len(scheme):])
if len(hashed_passord) < cls._schemes_to_len[scheme]:
raise cls.BadHash("Hash too short for the scheme %s" % scheme)
return hashed_passord[cls._schemes_to_len[scheme]:]
def check_password(method, password, hashed_password, charset):
"""
Check that `password` match `hashed_password` using `method`,
assuming the encoding is `charset`.
"""
if not isinstance(password, six.binary_type):
password = password.encode(charset)
if not isinstance(hashed_password, six.binary_type):
hashed_password = hashed_password.encode(charset)
if method == "plain":
return password == hashed_password
elif method == "crypt":
if hashed_password.startswith(b'$'):
salt = b'$'.join(hashed_password.split(b'$', 3)[:-1])
elif hashed_password.startswith(b'_'):
salt = hashed_password[:9]
else:
salt = hashed_password[:2]
if six.PY3:
password = password.decode(charset)
salt = salt.decode(charset)
hashed_password = hashed_password.decode(charset)
crypted_password = crypt.crypt(password, salt)
if crypted_password is None:
raise ValueError("System crypt implementation do not support the salt %r" % salt)
return crypted_password == hashed_password
elif method == "ldap":
scheme = LdapHashUserPassword.get_scheme(hashed_password)
salt = LdapHashUserPassword.get_salt(hashed_password)
return LdapHashUserPassword.hash(scheme, password, salt, charset=charset) == hashed_password
elif (
method.startswith("hex_") and
method[4:] in {"md5", "sha1", "sha224", "sha256", "sha384", "sha512"}
):
return getattr(
hashlib,
method[4:]
)(password).hexdigest().encode("ascii") == hashed_password.lower()
else:
raise ValueError("Unknown password method check %r" % method)

View File

@ -23,6 +23,7 @@ from django.views.decorators.csrf import csrf_exempt
from django.middleware.csrf import CsrfViewMiddleware
from django.views.generic import View
import re
import logging
import pprint
import requests
@ -34,7 +35,7 @@ import cas_server.utils as utils
import cas_server.forms as forms
import cas_server.models as models
from .utils import JsonResponse
from .utils import json_response
from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket
from .models import ServicePattern
from .federate import CASFederateValidateUser
@ -63,12 +64,12 @@ class AttributesMixin(object):
class LogoutMixin(object):
"""destroy CAS session utils"""
def logout(self, all=False):
def logout(self, all_session=False):
"""effectively destroy CAS session"""
session_nb = 0
username = self.request.session.get("username")
if username:
if all:
if all_session:
logger.info("Logging out user %s from all of they sessions." % username)
else:
logger.info("Logging out user %s." % username)
@ -91,8 +92,8 @@ class LogoutMixin(object):
# if user not found in database, flush the session anyway
self.request.session.flush()
# If all is set logout user from alternative sessions
if all:
# If all_session is set logout user from alternative sessions
if all_session:
for user in models.User.objects.filter(username=username):
session = SessionStore(session_key=user.session_key)
session.flush()
@ -110,6 +111,7 @@ class LogoutView(View, LogoutMixin):
service = None
def init_get(self, request):
"""Initialize GET received parameters"""
self.request = request
self.service = request.GET.get('service')
self.url = request.GET.get('url')
@ -170,13 +172,13 @@ class LogoutView(View, LogoutMixin):
'url': url,
'session_nb': session_nb
}
return JsonResponse(request, data)
return json_response(request, data)
else:
return redirect("cas_server:login")
else:
if self.ajax:
data = {'status': 'success', 'detail': 'logout', 'session_nb': session_nb}
return JsonResponse(request, data)
return json_response(request, data)
else:
return render(
request,
@ -290,12 +292,10 @@ class LoginView(View, LogoutMixin):
USER_NOT_AUTHENTICATED = 6
def init_post(self, request):
"""Initialize POST received parameters"""
self.request = request
self.service = request.POST.get('service')
if request.POST.get('renew') and request.POST['renew'] != "False":
self.renew = True
else:
self.renew = False
self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False")
self.gateway = request.POST.get('gateway')
self.method = request.POST.get('method')
self.ajax = 'HTTP_X_AJAX' in request.META
@ -306,15 +306,19 @@ class LoginView(View, LogoutMixin):
self.username = request.POST.get('username')
self.ticket = request.POST.get('ticket')
def check_lt(self):
# save LT for later check
lt_valid = self.request.session.get('lt', [])
lt_send = self.request.POST.get('lt')
# generate a new LT (by posting the LT has been consumed)
def gen_lt(self):
"""Generate a new LoginTicket and add it to the list of valid LT for the user"""
self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()]
if len(self.request.session['lt']) > 100:
self.request.session['lt'] = self.request.session['lt'][-100:]
def check_lt(self):
"""Check is the POSTed LoginTicket is valid, if yes invalide it"""
# save LT for later check
lt_valid = self.request.session.get('lt', [])
lt_send = self.request.POST.get('lt')
# generate a new LT (by posting the LT has been consumed)
self.gen_lt()
# check if send LT is valid
if lt_valid is None or lt_send not in lt_valid:
return False
@ -339,7 +343,7 @@ class LoginView(View, LogoutMixin):
username=self.request.session['username'],
session_key=self.request.session.session_key
)
self.user.save()
self.user.save() # pragma: no cover (should not happend)
except models.User.DoesNotExist:
self.user = models.User.objects.create(
username=self.request.session['username'],
@ -355,10 +359,15 @@ class LoginView(View, LogoutMixin):
elif ret == self.USER_ALREADY_LOGGED:
pass
else:
raise EnvironmentError("invalid output for LoginView.process_post")
raise EnvironmentError("invalid output for LoginView.process_post") # pragma: no cover
return self.common()
def process_post(self, pytest=False):
def process_post(self):
"""
Analyse the POST request:
* check that the LoginTicket is valid
* check that the user sumited credentials are valid
"""
if not self.check_lt():
values = self.request.POST.copy()
# if not set a new LT and fail
@ -385,12 +394,10 @@ class LoginView(View, LogoutMixin):
return self.USER_ALREADY_LOGGED
def init_get(self, request):
"""Initialize GET received parameters"""
self.request = request
self.service = request.GET.get('service')
if request.GET.get('renew') and request.GET['renew'] != "False":
self.renew = True
else:
self.renew = False
self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False")
self.gateway = request.GET.get('gateway')
self.method = request.GET.get('method')
self.ajax = 'HTTP_X_AJAX' in request.META
@ -410,15 +417,16 @@ class LoginView(View, LogoutMixin):
return self.common()
def process_get(self):
# generate a new LT if none is present
self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()]
"""Analyse the GET request"""
# generate a new LT
self.gen_lt()
if not self.request.session.get("authenticated") or self.renew:
self.init_form()
return self.USER_NOT_AUTHENTICATED
return self.USER_AUTHENTICATED
def init_form(self, values=None):
"""Initialization of the good form depending of POST and GET parameters"""
form_initial = {
'service': self.service,
'method': self.method,
@ -459,7 +467,7 @@ class LoginView(View, LogoutMixin):
)
if self.ajax:
data = {"status": "error", "detail": "confirmation needed"}
return JsonResponse(self.request, data)
return json_response(self.request, data)
else:
warn_form = forms.WarnForm(initial={
'service': self.service,
@ -486,7 +494,7 @@ class LoginView(View, LogoutMixin):
return HttpResponseRedirect(redirect_url)
else:
data = {"status": "success", "detail": "auth", "url": redirect_url}
return JsonResponse(self.request, data)
return json_response(self.request, data)
except ServicePattern.DoesNotExist:
error = 1
messages.add_message(
@ -530,7 +538,7 @@ class LoginView(View, LogoutMixin):
)
else:
data = {"status": "error", "detail": "auth", "code": error}
return JsonResponse(self.request, data)
return json_response(self.request, data)
def authenticated(self):
"""Processing authenticated users"""
@ -552,7 +560,7 @@ class LoginView(View, LogoutMixin):
"detail": "login required",
"url": utils.reverse_params("cas_server:login", params=self.request.GET)
}
return JsonResponse(self.request, data)
return json_response(self.request, data)
else:
return utils.redirect_params("cas_server:login", params=self.request.GET)
@ -562,7 +570,7 @@ class LoginView(View, LogoutMixin):
else:
if self.ajax:
data = {"status": "success", "detail": "logged"}
return JsonResponse(self.request, data)
return json_response(self.request, data)
else:
return render(
self.request,
@ -605,7 +613,7 @@ class LoginView(View, LogoutMixin):
"detail": "login required",
"url": utils.reverse_params("cas_server:login", params=self.request.GET)
}
return JsonResponse(self.request, data)
return json_response(self.request, data)
else:
if settings.CAS_FEDERATE:
if self.username and self.ticket:
@ -824,7 +832,10 @@ class ValidateService(View, AttributesMixin):
params['username'] = self.ticket.user.attributs.get(
self.ticket.service_pattern.user_field
)
if self.pgt_url and self.pgt_url.startswith("https://"):
if self.pgt_url and (
self.pgt_url.startswith("https://") or
re.match("^http://(127\.0\.0\.1|localhost)(:[0-9]+)?(/.*)?$", self.pgt_url)
):
return self.process_pgturl(params)
else:
logger.info(