Add unit tests for when CAS_FEDERATE is True
Also fix some unicode related bugs
This commit is contained in:
@ -123,15 +123,18 @@ class LogoutView(View, LogoutMixin):
|
||||
self.init_get(request)
|
||||
# if CAS federation mode is enable, bakup the provider before flushing the sessions
|
||||
if settings.CAS_FEDERATE:
|
||||
component = self.request.session.get("username").split('@')
|
||||
provider = component[-1]
|
||||
auth = CASFederateValidateUser(provider, service_url="")
|
||||
if "username" in self.request.session:
|
||||
component = self.request.session["username"].split('@')
|
||||
provider = component[-1]
|
||||
auth = CASFederateValidateUser(provider, service_url="")
|
||||
else:
|
||||
auth = None
|
||||
session_nb = self.logout(self.request.GET.get("all"))
|
||||
# if CAS federation mode is enable, redirect to user CAS logout page
|
||||
if settings.CAS_FEDERATE:
|
||||
params = utils.copy_params(request.GET)
|
||||
url = utils.update_url(auth.get_logout_url(), params)
|
||||
if url:
|
||||
if auth is not None:
|
||||
params = utils.copy_params(request.GET)
|
||||
url = utils.update_url(auth.get_logout_url(), params)
|
||||
return HttpResponseRedirect(url)
|
||||
# if service is set, redirect to service after logout
|
||||
if self.service:
|
||||
@ -195,7 +198,7 @@ class FederateAuth(View):
|
||||
|
||||
@staticmethod
|
||||
def get_cas_client(request, provider):
|
||||
if provider in settings.CAS_FEDERATE_PROVIDERS:
|
||||
if provider in settings.CAS_FEDERATE_PROVIDERS: # pragma: no branch (should always be true)
|
||||
service_url = utils.get_current_url(request, {"ticket", "provider"})
|
||||
return CASFederateValidateUser(provider, service_url)
|
||||
|
||||
@ -207,14 +210,14 @@ class FederateAuth(View):
|
||||
auth = self.get_cas_client(request, provider)
|
||||
try:
|
||||
auth.clean_sessions(request.POST['logoutRequest'])
|
||||
except KeyError:
|
||||
except (KeyError, AttributeError):
|
||||
pass
|
||||
return HttpResponse("ok")
|
||||
# else, a User is trying to log in using an identity provider
|
||||
else:
|
||||
# Manually checking for csrf to protect the code below
|
||||
reason = CsrfViewMiddleware().process_view(request, None, (), {})
|
||||
if reason is not None:
|
||||
if reason is not None: # pragma: no cover (csrf checks are disabled during tests)
|
||||
return reason # Failed the test, stop here.
|
||||
form = forms.FederateSelect(request.POST)
|
||||
if form.is_valid():
|
||||
@ -252,7 +255,7 @@ class FederateAuth(View):
|
||||
ticket = request.GET['ticket']
|
||||
if auth.verify_ticket(ticket):
|
||||
params = utils.copy_params(request.GET, ignore={"ticket"})
|
||||
username = "%s@%s" % (auth.username, auth.provider)
|
||||
username = u"%s@%s" % (auth.username, auth.provider)
|
||||
request.session["federate_username"] = username
|
||||
request.session["federate_ticket"] = ticket
|
||||
auth.register_slo(username, request.session.session_key, ticket)
|
||||
@ -281,9 +284,9 @@ class LoginView(View, LogoutMixin):
|
||||
renewed = False
|
||||
warned = False
|
||||
|
||||
if settings.CAS_FEDERATE:
|
||||
username = None
|
||||
ticket = None
|
||||
# used if CAS_FEDERATE is True
|
||||
username = None
|
||||
ticket = None
|
||||
|
||||
INVALID_LOGIN_TICKET = 1
|
||||
USER_LOGIN_OK = 2
|
||||
@ -354,7 +357,7 @@ class LoginView(View, LogoutMixin):
|
||||
elif ret == self.USER_LOGIN_FAILURE: # bad user login
|
||||
if settings.CAS_FEDERATE:
|
||||
self.ticket = None
|
||||
self.usernalme = None
|
||||
self.username = None
|
||||
self.init_form()
|
||||
self.logout()
|
||||
elif ret == self.USER_ALREADY_LOGGED:
|
||||
@ -682,11 +685,14 @@ class Auth(View):
|
||||
secret = request.POST.get('secret')
|
||||
|
||||
if not settings.CAS_AUTH_SHARED_SECRET:
|
||||
return HttpResponse("no\nplease set CAS_AUTH_SHARED_SECRET", content_type="text/plain")
|
||||
return HttpResponse(
|
||||
"no\nplease set CAS_AUTH_SHARED_SECRET",
|
||||
content_type="text/plain; charset=utf-8"
|
||||
)
|
||||
if secret != settings.CAS_AUTH_SHARED_SECRET:
|
||||
return HttpResponse("no\n", content_type="text/plain")
|
||||
return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8")
|
||||
if not username or not password or not service:
|
||||
return HttpResponse("no\n", content_type="text/plain")
|
||||
return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8")
|
||||
form = forms.UserCredential(
|
||||
request.POST,
|
||||
initial={
|
||||
@ -714,11 +720,11 @@ class Auth(View):
|
||||
service_pattern.check_user(user)
|
||||
if not request.session.get("authenticated"):
|
||||
user.delete()
|
||||
return HttpResponse("yes\n", content_type="text/plain")
|
||||
return HttpResponse(u"yes\n", content_type="text/plain; charset=utf-8")
|
||||
except (ServicePattern.DoesNotExist, models.ServicePatternException):
|
||||
return HttpResponse("no\n", content_type="text/plain")
|
||||
return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8")
|
||||
else:
|
||||
return HttpResponse("no\n", content_type="text/plain")
|
||||
return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8")
|
||||
|
||||
|
||||
class Validate(View):
|
||||
@ -758,7 +764,10 @@ class Validate(View):
|
||||
username = username[0]
|
||||
else:
|
||||
username = ticket.user.username
|
||||
return HttpResponse("yes\n%s\n" % username, content_type="text/plain")
|
||||
return HttpResponse(
|
||||
u"yes\n%s\n" % username,
|
||||
content_type="text/plain; charset=utf-8"
|
||||
)
|
||||
except ServiceTicket.DoesNotExist:
|
||||
logger.warning(
|
||||
(
|
||||
@ -769,10 +778,10 @@ class Validate(View):
|
||||
service
|
||||
)
|
||||
)
|
||||
return HttpResponse("no\n", content_type="text/plain")
|
||||
return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8")
|
||||
else:
|
||||
logger.warning("Validate: service or ticket missing")
|
||||
return HttpResponse("no\n", content_type="text/plain")
|
||||
return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8")
|
||||
|
||||
|
||||
class ValidateError(Exception):
|
||||
@ -815,8 +824,8 @@ class ValidateService(View, AttributesMixin):
|
||||
if not self.service or not self.ticket:
|
||||
logger.warning("ValidateService: missing ticket or service")
|
||||
return ValidateError(
|
||||
'INVALID_REQUEST',
|
||||
"you must specify a service and a ticket"
|
||||
u'INVALID_REQUEST',
|
||||
u"you must specify a service and a ticket"
|
||||
).render(request)
|
||||
else:
|
||||
try:
|
||||
@ -886,14 +895,14 @@ class ValidateService(View, AttributesMixin):
|
||||
for prox in ticket.proxies.all():
|
||||
proxies.append(prox.url)
|
||||
else:
|
||||
raise ValidateError('INVALID_TICKET', self.ticket)
|
||||
raise ValidateError(u'INVALID_TICKET', self.ticket)
|
||||
ticket.validate = True
|
||||
ticket.save()
|
||||
if ticket.service != self.service:
|
||||
raise ValidateError('INVALID_SERVICE', self.service)
|
||||
raise ValidateError(u'INVALID_SERVICE', self.service)
|
||||
return ticket, proxies
|
||||
except (ServiceTicket.DoesNotExist, ProxyTicket.DoesNotExist):
|
||||
raise ValidateError('INVALID_TICKET', 'ticket not found')
|
||||
raise ValidateError(u'INVALID_TICKET', 'ticket not found')
|
||||
|
||||
def process_pgturl(self, params):
|
||||
"""Handle PGT request"""
|
||||
@ -939,18 +948,18 @@ class ValidateService(View, AttributesMixin):
|
||||
except requests.exceptions.RequestException as error:
|
||||
error = utils.unpack_nested_exception(error)
|
||||
raise ValidateError(
|
||||
'INVALID_PROXY_CALLBACK',
|
||||
"%s: %s" % (type(error), str(error))
|
||||
u'INVALID_PROXY_CALLBACK',
|
||||
u"%s: %s" % (type(error), str(error))
|
||||
)
|
||||
else:
|
||||
raise ValidateError(
|
||||
'INVALID_PROXY_CALLBACK',
|
||||
"callback url not allowed by configuration"
|
||||
u'INVALID_PROXY_CALLBACK',
|
||||
u"callback url not allowed by configuration"
|
||||
)
|
||||
except ServicePattern.DoesNotExist:
|
||||
raise ValidateError(
|
||||
'INVALID_PROXY_CALLBACK',
|
||||
'callback url not allowed by configuration'
|
||||
u'INVALID_PROXY_CALLBACK',
|
||||
u'callback url not allowed by configuration'
|
||||
)
|
||||
|
||||
|
||||
@ -971,8 +980,8 @@ class Proxy(View):
|
||||
return self.process_proxy()
|
||||
else:
|
||||
raise ValidateError(
|
||||
'INVALID_REQUEST',
|
||||
"you must specify and pgt and targetService"
|
||||
u'INVALID_REQUEST',
|
||||
u"you must specify and pgt and targetService"
|
||||
)
|
||||
except ValidateError as error:
|
||||
logger.warning("Proxy: validation error: %s %s" % (error.code, error.msg))
|
||||
@ -985,8 +994,8 @@ class Proxy(View):
|
||||
pattern = ServicePattern.validate(self.target_service)
|
||||
if not pattern.proxy:
|
||||
raise ValidateError(
|
||||
'UNAUTHORIZED_SERVICE',
|
||||
'the service %s do not allow proxy ticket' % self.target_service
|
||||
u'UNAUTHORIZED_SERVICE',
|
||||
u'the service %s do not allow proxy ticket' % self.target_service
|
||||
)
|
||||
# is the proxy granting ticket valid
|
||||
ticket = ProxyGrantingTicket.objects.get(
|
||||
@ -1015,13 +1024,13 @@ class Proxy(View):
|
||||
content_type="text/xml; charset=utf-8"
|
||||
)
|
||||
except ProxyGrantingTicket.DoesNotExist:
|
||||
raise ValidateError('INVALID_TICKET', 'PGT %s not found' % self.pgt)
|
||||
raise ValidateError(u'INVALID_TICKET', u'PGT %s not found' % self.pgt)
|
||||
except ServicePattern.DoesNotExist:
|
||||
raise ValidateError('UNAUTHORIZED_SERVICE', self.target_service)
|
||||
raise ValidateError(u'UNAUTHORIZED_SERVICE', self.target_service)
|
||||
except (models.BadUsername, models.BadFilter, models.UserFieldNotDefined):
|
||||
raise ValidateError(
|
||||
'UNAUTHORIZED_USER',
|
||||
'User %s not allowed on %s' % (ticket.user.username, self.target_service)
|
||||
u'UNAUTHORIZED_USER',
|
||||
u'User %s not allowed on %s' % (ticket.user.username, self.target_service)
|
||||
)
|
||||
|
||||
|
||||
@ -1129,18 +1138,18 @@ class SamlValidate(View, AttributesMixin):
|
||||
)
|
||||
else:
|
||||
raise SamlValidateError(
|
||||
'AuthnFailed',
|
||||
'ticket %s should begin with PT- or ST-' % ticket
|
||||
u'AuthnFailed',
|
||||
u'ticket %s should begin with PT- or ST-' % ticket
|
||||
)
|
||||
ticket.validate = True
|
||||
ticket.save()
|
||||
if ticket.service != self.target:
|
||||
raise SamlValidateError(
|
||||
'AuthnFailed',
|
||||
'TARGET %s do not match ticket service' % self.target
|
||||
u'AuthnFailed',
|
||||
u'TARGET %s do not match ticket service' % self.target
|
||||
)
|
||||
return ticket
|
||||
except (IndexError, KeyError):
|
||||
raise SamlValidateError('VersionMismatch')
|
||||
raise SamlValidateError(u'VersionMismatch')
|
||||
except (ServiceTicket.DoesNotExist, ProxyTicket.DoesNotExist):
|
||||
raise SamlValidateError('AuthnFailed', 'ticket %s not found' % ticket)
|
||||
raise SamlValidateError(u'AuthnFailed', u'ticket %s not found' % ticket)
|
||||
|
Reference in New Issue
Block a user