add some tests

This commit is contained in:
Valentin Samir
2015-06-21 18:56:16 +02:00
parent c0d8550120
commit 50781dba18
13 changed files with 195 additions and 78 deletions

View File

@ -1,3 +1,4 @@
import functools
from cas_server import models
class DummyUserManager(object):
@ -10,6 +11,75 @@ class DummyUserManager(object):
else:
raise models.User.DoesNotExist()
def dummy(*args, **kwds):
pass
def dummy_service_pattern(**kwargs):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwds):
service_validate = models.ServicePattern.validate
models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern(**kwargs))
ret = func(*args, **kwds)
models.ServicePattern.validate = service_validate
return ret
return wrapper
return decorator
def dummy_user(username, session_key):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwds):
user_manager = models.User.objects
user_save = models.User.save
user_delete = models.User.delete
models.User.objects = DummyUserManager(username, session_key)
models.User.save = dummy
models.User.delete = dummy
ret = func(*args, **kwds)
models.User.objects = user_manager
models.User.save = user_save
models.User.delete = user_delete
return ret
return wrapper
return decorator
def dummy_ticket(ticket_class, service, ticket):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwds):
ticket_manager = ticket_class.objects
ticket_save = ticket_class.save
ticket_delete = ticket_class.delete
ticket_class.objects = DummyTicketManager(ticket_class, service, ticket)
ticket_class.save = dummy
ticket_class.delete = dummy
ret = func(*args, **kwds)
ticket_class.objects = ticket_manager
ticket_class.save = ticket_save
ticket_class.delete = ticket_delete
return ret
return wrapper
return decorator
def dummy_proxy(func):
@functools.wraps(func)
def wrapper(*args, **kwds):
proxy_manager = models.Proxy.objects
models.Proxy.objects = DummyProxyManager()
ret = func(*args, **kwds)
models.Proxy.objects = proxy_manager
return ret
return wrapper
class DummyProxyManager(object):
def create(self, **kwargs):
for field in models.Proxy._meta.fields:
field.allow_unsaved_instance_assignment = True
return models.Proxy(**kwargs)
class DummyTicketManager(object):
def __init__(self, ticket_class, service, ticket):
self.ticket_class = ticket_class
@ -17,7 +87,7 @@ class DummyTicketManager(object):
self.ticket = ticket
def create(self, **kwargs):
for field in models.ServiceTicket._meta.fields:
for field in self.ticket_class._meta.fields:
field.allow_unsaved_instance_assignment = True
return self.ticket_class(**kwargs)
@ -25,6 +95,8 @@ class DummyTicketManager(object):
return DummyQuerySet()
def get(self, **kwargs):
for field in self.ticket_class._meta.fields:
field.allow_unsaved_instance_assignment = True
if 'value' in kwargs:
if kwargs['value'] != self.ticket:
raise self.ticket_class.DoesNotExist()
@ -41,7 +113,7 @@ class DummyTicketManager(object):
for field in models.ServiceTicket._meta.fields:
field.allow_unsaved_instance_assignment = True
for key in kwargs.keys():
for key in list(kwargs):
if '__' in key:
del kwargs[key]
kwargs['attributs'] = {'mail': 'test@example.com'}

52
tests/test_proxy.py Normal file
View File

@ -0,0 +1,52 @@
from __future__ import absolute_import
from tests.init import *
from django.test import RequestFactory
import os
import pytest
from lxml import etree
from cas_server.views import ValidateService, Proxy
from cas_server import models
from tests.dummy import *
@pytest.mark.django_db
@dummy_ticket(models.ProxyGrantingTicket, '', "PGT-random")
@dummy_service_pattern(proxy=True)
@dummy_user(username="test", session_key="test_session")
@dummy_ticket(models.ProxyTicket, "https://www.example.com", "PT-random")
@dummy_proxy
def test_proxy_ok():
factory = RequestFactory()
request = factory.get('/proxy?pgt=PGT-random&targetService=https://www.example.com')
request.session = DummySession()
proxy = Proxy()
response = proxy.get(request)
assert response.status_code == 200
root = etree.fromstring(response.content)
proxy_tickets = root.xpath("//cas:proxyTicket", namespaces={'cas': "http://www.yale.edu/tp/cas"})
assert len(proxy_tickets) == 1
factory = RequestFactory()
request = factory.get('/proxyValidate?ticket=PT-random&service=https://www.example.com')
validate = ValidateService()
validate.allow_proxy_ticket = True
response = validate.get(request)
assert response.status_code == 200
root = etree.fromstring(response.content)
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
assert len(users) == 1
assert users[0].text == "test"

View File

@ -12,15 +12,13 @@ from cas_server import models
from .dummy import *
@pytest.mark.django_db
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
def test_validate_service_view_ok():
factory = RequestFactory()
request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example.com')
request.session = DummySession()
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
models.ServiceTicket.save = lambda x:None
validate = ValidateService()
validate.allow_proxy_ticket = False
response = validate.get(request)
@ -47,15 +45,13 @@ def test_validate_service_view_ok():
@pytest.mark.django_db
@dummy_ticket(models.ServiceTicket, 'https://www.example2.com', "ST-random")
def test_validate_service_view_badservice():
factory = RequestFactory()
request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example1.com')
request.session = DummySession()
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example2.com', "ST-random")
models.ServiceTicket.save = lambda x:None
validate = ValidateService()
validate.allow_proxy_ticket = False
response = validate.get(request)
@ -70,15 +66,13 @@ def test_validate_service_view_badservice():
assert error[0].attrib['code'] == 'INVALID_SERVICE'
@pytest.mark.django_db
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random2")
def test_validate_service_view_badticket():
factory = RequestFactory()
request = factory.get('/serviceValidate?ticket=ST-random1&service=https://www.example.com')
request.session = DummySession()
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random2")
models.ServiceTicket.save = lambda x:None
validate = ValidateService()
validate.allow_proxy_ticket = False
response = validate.get(request)

View File

@ -14,36 +14,33 @@ from .dummy import *
settings.CAS_AUTH_SHARED_SECRET = "test"
@pytest.mark.django_db
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
@dummy_user(username="test", session_key="test_session")
@dummy_service_pattern()
def test_auth_view_goodpass():
factory = RequestFactory()
request = factory.post('/auth', {'username':'test', 'password':'test', 'service':'https://www.example.com', 'secret':'test'})
request.session = DummySession()
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern())
auth = Auth()
response = auth.post(request)
assert response.status_code == 200
assert response.content == "yes\n"
assert response.content == b"yes\n"
@dummy_service_pattern()
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
@dummy_user(username="test", session_key="test_session")
def test_auth_view_badpass():
factory = RequestFactory()
request = factory.post('/auth', {'username':'test', 'password':'badpass', 'service':'https://www.example.com', 'secret':'test'})
request.session = DummySession()
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern())
auth = Auth()
response = auth.post(request)
assert response.status_code == 200
assert response.content == "no\n"
assert response.content == b"no\n"

View File

@ -92,6 +92,7 @@ def test_view_login_get_unauth():
assert response.status_code == 200
@pytest.mark.django_db
@dummy_user(username="test", session_key="test_session")
def test_view_login_get_auth():
factory = RequestFactory()
request = factory.post('/login')
@ -107,14 +108,15 @@ def test_view_login_get_auth():
assert ret == LoginView.USER_AUTHENTICATED
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
login = LoginView()
response = login.get(request)
assert response.status_code == 200
@pytest.mark.django_db
@dummy_service_pattern()
@dummy_user(username="test", session_key="test_session")
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
def test_view_login_get_auth_service():
factory = RequestFactory()
request = factory.post('/login?service=https://www.example.com')
@ -130,12 +132,6 @@ def test_view_login_get_auth_service():
assert ret == LoginView.USER_AUTHENTICATED
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
models.User.save = lambda x:None
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern())
models.ServiceTicket.save = lambda x:None
login = LoginView()
response = login.get(request)
@ -143,6 +139,9 @@ def test_view_login_get_auth_service():
assert response['Location'].startswith('https://www.example.com?ticket=ST-')
@pytest.mark.django_db
@dummy_service_pattern()
@dummy_user(username="test", session_key="test_session")
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
def test_view_login_get_auth_service_warn():
factory = RequestFactory()
request = factory.post('/login?service=https://www.example.com')
@ -158,12 +157,6 @@ def test_view_login_get_auth_service_warn():
assert ret == LoginView.USER_AUTHENTICATED
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
models.User.save = lambda x:None
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern())
models.ServiceTicket.save = lambda x:None
login = LoginView()
response = login.get(request)

View File

@ -13,6 +13,7 @@ from .dummy import *
@pytest.mark.django_db
@dummy_user(username="test", session_key="test_session")
def test_logout_view():
factory = RequestFactory()
request = factory.get('/logout')
@ -23,21 +24,17 @@ def test_logout_view():
request.session["username"] = "test"
request.session["warn"] = False
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
dlist = [None]
models.User.delete = lambda x:dlist.pop()
logout = LogoutView()
response = logout.get(request)
assert response.status_code == 200
assert dlist == []
assert not request.session.get("authenticated")
assert not request.session.get("username")
assert not request.session.get("warn")
@pytest.mark.django_db
@dummy_user(username="test", session_key="test_session")
def test_logout_view_url():
factory = RequestFactory()
request = factory.get('/logout?url=https://www.example.com')
@ -48,16 +45,11 @@ def test_logout_view_url():
request.session["username"] = "test"
request.session["warn"] = False
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
dlist = [None]
models.User.delete = lambda x:dlist.pop()
logout = LogoutView()
response = logout.get(request)
assert response.status_code == 302
assert response['Location'] == 'https://www.example.com'
assert dlist == []
assert not request.session.get("authenticated")
assert not request.session.get("username")
assert not request.session.get("warn")
@ -65,6 +57,7 @@ def test_logout_view_url():
@pytest.mark.django_db
@dummy_user(username="test", session_key="test_session")
def test_logout_view_service():
factory = RequestFactory()
request = factory.get('/logout?service=https://www.example.com')
@ -75,16 +68,11 @@ def test_logout_view_service():
request.session["username"] = "test"
request.session["warn"] = False
models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
dlist = [None]
models.User.delete = lambda x:dlist.pop()
logout = LogoutView()
response = logout.get(request)
assert response.status_code == 302
assert response['Location'] == 'https://www.example.com'
assert dlist == []
assert not request.session.get("authenticated")
assert not request.session.get("username")
assert not request.session.get("warn")

View File

@ -12,50 +12,47 @@ from cas_server import models
from .dummy import *
@pytest.mark.django_db
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
def test_validate_view_ok():
factory = RequestFactory()
request = factory.get('/validate?ticket=ST-random&service=https://www.example.com')
request.session = DummySession()
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
validate = Validate()
response = validate.get(request)
assert response.status_code == 200
assert response.content == "yes\n"
assert response.content == b"yes\n"
@pytest.mark.django_db
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
def test_validate_view_badservice():
factory = RequestFactory()
request = factory.get('/validate?ticket=ST-random&service=https://www.example2.com')
request.session = DummySession()
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
validate = Validate()
response = validate.get(request)
assert response.status_code == 200
assert response.content == "no\n"
assert response.content == b"no\n"
@pytest.mark.django_db
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random1")
def test_validate_view_badticket():
factory = RequestFactory()
request = factory.get('/validate?ticket=ST-random2&service=https://www.example.com')
request.session = DummySession()
models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random1")
validate = Validate()
response = validate.get(request)
assert response.status_code == 200
assert response.content == "no\n"
assert response.content == b"no\n"