Display trains that are near a station

This commit is contained in:
2024-05-11 20:52:22 +02:00
parent 735191947d
commit 070849c427
12 changed files with 175 additions and 165 deletions

View File

@ -1,9 +1,18 @@
from rest_framework import serializers
from trainvel.core.models import Station
from trainvel.gtfs.models import Agency, Stop, Route, Trip, StopTime, Calendar, CalendarDate, \
Transfer, FeedInfo, TripUpdate, StopTimeUpdate
class StationSerializer(serializers.ModelSerializer):
class Meta:
model = Station
lookup_field = 'slug'
fields = ('id', 'slug', 'name', 'uic', 'uic8_sncf', 'latitude', 'longitude', 'country',
'country_hint', 'main_station_hint',)
class AgencySerializer(serializers.ModelSerializer):
class Meta:
model = Agency

View File

@ -10,13 +10,28 @@ from rest_framework.filters import OrderingFilter, SearchFilter
from trainvel.api.serializers import AgencySerializer, StopSerializer, RouteSerializer, TripSerializer, \
StopTimeSerializer, CalendarSerializer, CalendarDateSerializer, TransferSerializer, \
FeedInfoSerializer, TripUpdateSerializer, StopTimeUpdateSerializer
from trainvel.gtfs.models import Agency, Calendar, CalendarDate, FeedInfo, GTFSFeed, Route, Stop, StopTime, StopTimeUpdate, \
Transfer, Trip, TripUpdate
FeedInfoSerializer, TripUpdateSerializer, StopTimeUpdateSerializer, StationSerializer
from trainvel.core.models import Station
from trainvel.gtfs.models import Agency, Calendar, CalendarDate, FeedInfo, GTFSFeed, Route, Stop, StopTime, \
StopTimeUpdate, \
Transfer, Trip, TripUpdate, PickupType
CACHE_CONTROL = cache_control(max_age=7200)
CACHE_CONTROL = cache_control(max_age=30)
LAST_MODIFIED = last_modified(lambda *args, **kwargs: GTFSFeed.objects.order_by('-last_modified').first().last_modified)
LOOKUP_VALUE_REGEX = r"[\w.: |-]+"
LOOKUP_VALUE_REGEX = r"[\w.: |+-]+"
class StationViewSet(viewsets.ReadOnlyModelViewSet):
queryset = Station.objects.filter(is_suggestable=True)
serializer_class = StationSerializer
filter_backends = [DjangoFilterBackend, SearchFilter]
filterset_fields = '__all__'
search_fields = ['name', 'slug',
'info_de', 'info_en', 'info_es', 'info_fr', 'info_it', 'info_nb', 'info_nl', 'info_cs',
'info_da', 'info_hu', 'info_ja', 'info_ko', 'info_pl', 'info_pt', 'info_ru', 'info_sv',
'info_tr', 'info_zh', ]
lookup_field = 'slug'
lookup_value_regex = LOOKUP_VALUE_REGEX
@method_decorator(name='list', decorator=[CACHE_CONTROL, LAST_MODIFIED])
@ -135,8 +150,7 @@ class NextDeparturesViewSet(viewsets.ReadOnlyModelViewSet):
def get_queryset(self):
now = datetime.now()
stop_id = self.request.query_params.get('stop_id', None)
stop_name = self.request.query_params.get('stop_name', None)
station_slug = self.request.query_params.get('station_slug', None)
query_date = date.fromisoformat(self.request.query_params.get('date', now.date().isoformat()))
query_time = self.request.query_params.get('time', now.time().isoformat(timespec='seconds'))
query_time = timedelta(seconds=int(query_time[:2]) * 3600
@ -148,16 +162,10 @@ class NextDeparturesViewSet(viewsets.ReadOnlyModelViewSet):
tomorrow = query_date + timedelta(days=1)
stop_filter = Q(stop__location_type=0)
if stop_id:
stop = Stop.objects.get(id=stop_id)
stops = Stop.objects.filter(Q(id=stop_id)
| Q(parent_station=stop_id))
if stop.location_type == 0 and stop.parent_station_id is not None:
stops |= Stop.objects.filter(parent_station=stop.parent_station_id)
stop_filter = Q(stop__in=stops.values_list('id', flat=True))
elif stop_name:
stops = Stop.objects.filter(name__iexact=stop_name).values_list('id', flat=True)
stop_filter = Q(stop__in=stops)
if station_slug:
station = Station.objects.get(is_suggestable=True, slug=station_slug)
near_stops = station.get_near_stops()
stop_filter = Q(stop_id__in=near_stops.values_list('id', flat=True))
def calendar_filter(d: date):
return Q(trip__service_id__in=CalendarDate.objects.filter(date=d, exception_type=1)
@ -189,7 +197,7 @@ class NextDeparturesViewSet(viewsets.ReadOnlyModelViewSet):
qs_today = StopTime.objects.filter(stop_filter) \
.annotate(departure_time_real=departure_time_real(query_date)) \
.filter(departure_time_real__gte=query_time) \
.filter(Q(pickup_type=0) | canceled_filter(query_date)) \
.filter(Q(pickup_type=PickupType.REGULAR) | canceled_filter(query_date)) \
.filter(calendar_filter(query_date)) \
.annotate(departure_date=Value(query_date)) \
.annotate(departure_time_24h=F('departure_time'))
@ -221,8 +229,7 @@ class NextArrivalsViewSet(viewsets.ReadOnlyModelViewSet):
def get_queryset(self):
now = datetime.now()
stop_id = self.request.query_params.get('stop_id', None)
stop_name = self.request.query_params.get('stop_name', None)
station_slug = self.request.query_params.get('station_slug', None)
query_date = date.fromisoformat(self.request.query_params.get('date', now.date().isoformat()))
query_time = self.request.query_params.get('time', now.time().isoformat(timespec='seconds'))
query_time = timedelta(seconds=int(query_time[:2]) * 3600
@ -235,16 +242,10 @@ class NextArrivalsViewSet(viewsets.ReadOnlyModelViewSet):
tomorrow = query_date + timedelta(days=1)
stop_filter = Q(stop__location_type=0)
if stop_id:
stop = Stop.objects.get(id=stop_id)
stops = Stop.objects.filter(Q(id=stop_id)
| Q(parent_station=stop_id))
if stop.location_type == 0 and stop.parent_station_id is not None:
stops |= Stop.objects.filter(parent_station=stop.parent_station_id)
stop_filter = Q(stop__in=stops.values_list('id', flat=True))
elif stop_name:
stops = Stop.objects.filter(name__iexact=stop_name).values_list('id', flat=True)
stop_filter = Q(stop__in=stops)
if station_slug:
station = Station.objects.get(is_suggestable=True, slug=station_slug)
near_stops = station.get_near_stops()
stop_filter = Q(stop_id__in=near_stops.values_list('id', flat=True))
def calendar_filter(d: date):
return Q(trip__service_id__in=CalendarDate.objects.filter(date=d, exception_type=1)

View File

@ -1,7 +1,10 @@
from django.conf import settings
from django.db import models
from django.db.models import F, QuerySet
from django.db.models.functions import ACos, Sin, Radians, Cos
from django.utils.translation import gettext_lazy as _
from trainvel.gtfs.models import Country
from trainvel.gtfs.models import Country, Stop
class Station(models.Model):
@ -498,6 +501,16 @@ class Station(models.Model):
default=None,
)
def get_near_stops(self, radius: float = settings.STATION_RADIUS) -> QuerySet[Stop]:
"""
Returns a queryset of all stops that are in a radius of radius meters around the station.
It calculates a distance from each stop to the station using spatial coordinates.
"""
return Stop.objects.annotate(distance=6371000 * ACos(
Sin(Radians(self.latitude)) * Sin(Radians(F('lat')))
+ Cos(Radians(self.latitude)) * Cos(Radians(F('lat'))) * Cos(Radians(F('lon')) - Radians(self.longitude))))\
.filter(distance__lte=radius)
def __str__(self):
return self.name

View File

@ -294,14 +294,14 @@ class Command(BaseCommand):
dep_h, dep_m, dep_s = map(int, dep_time.split(':'))
dep_time = dep_h * 3600 + dep_m * 60 + dep_s
pickup_type = stop_time_dict.get('pickup_type', 0)
drop_off_type = stop_time_dict.get('drop_off_type', 0)
if stop_time_dict['stop_sequence'] == "1":
# First stop
drop_off_type = PickupType.NONE
elif arr_time == dep_time:
# Last stop
pickup_type = PickupType.NONE
pickup_type = stop_time_dict.get('pickup_type', PickupType.REGULAR)
drop_off_type = stop_time_dict.get('drop_off_type', PickupType.REGULAR)
# if stop_time_dict['stop_sequence'] == "1":
# # First stop
# drop_off_type = PickupType.NONE
# elif arr_time == dep_time:
# # Last stop
# pickup_type = PickupType.NONE
st = StopTime(
id=f"{gtfs_code}-{stop_time_dict['trip_id']}-{stop_time_dict['stop_id']}"
@ -349,7 +349,7 @@ class Command(BaseCommand):
from_stop_id=from_stop_id,
to_stop_id=to_stop_id,
transfer_type=transfer_dict['transfer_type'],
min_transfer_time=transfer_dict['min_transfer_time'],
min_transfer_time=transfer_dict.get('min_transfer_time', 0) or 0,
)
transfers.append(transfer)

View File

@ -30,7 +30,7 @@ class Command(BaseCommand):
headers = {}
if gtfs_code == "CH-ALL":
headers["Authorization"] = settings.OPENTRANSPORTDATA_SWISS_TOKEN
resp = requests.get(gtfs_feed.rt_feed_url, allow_redirects=True)
resp = requests.get(gtfs_feed.rt_feed_url, allow_redirects=True, headers=headers)
feed_message = FeedMessage()
feed_message.ParseFromString(resp.content)
@ -41,87 +41,88 @@ class Command(BaseCommand):
f.write(str(feed_message))
for entity in feed_message.entity:
if entity.HasField("trip_update"):
trip_update = entity.trip_update
trip_id = trip_update.trip.trip_id
trip_id = f"{gtfs_code}-{trip_id}"
try:
if entity.HasField("trip_update"):
trip_update = entity.trip_update
trip_id = trip_update.trip.trip_id
trip_id = f"{gtfs_code}-{trip_id}"
start_date = date(year=int(trip_update.trip.start_date[:4]),
month=int(trip_update.trip.start_date[4:6]),
day=int(trip_update.trip.start_date[6:]))
start_dt = datetime.combine(start_date, time(0), tzinfo=ZoneInfo("Europe/Paris"))
start_date = date(year=int(trip_update.trip.start_date[:4]),
month=int(trip_update.trip.start_date[4:6]),
day=int(trip_update.trip.start_date[6:]))
start_dt = datetime.combine(start_date, time(0), tzinfo=ZoneInfo("Europe/Paris"))
if trip_update.trip.schedule_relationship == TripScheduleRelationship.ADDED:
# C'est un trajet nouveau. On crée le trajet associé.
self.create_trip(trip_update, trip_id, start_dt, gtfs_feed)
if trip_update.trip.schedule_relationship == TripScheduleRelationship.ADDED:
# C'est un trajet nouveau. On crée le trajet associé.
self.create_trip(trip_update, trip_id, start_dt, gtfs_feed)
if not Trip.objects.filter(id=trip_id).exists():
self.stdout.write(f"Trip {trip_id} does not exist in the GTFS feed.")
continue
if not Trip.objects.filter(id=trip_id).exists():
self.stdout.write(f"Trip {trip_id} does not exist in the GTFS feed.")
continue
# Création du TripUpdate
tu, _created = TripUpdate.objects.update_or_create(
trip_id=trip_id,
start_date=trip_update.trip.start_date,
start_time=trip_update.trip.start_time,
defaults=dict(
schedule_relationship=trip_update.trip.schedule_relationship,
)
)
for stop_sequence, stop_time_update in enumerate(trip_update.stop_time_update):
stop_id = stop_time_update.stop_id
stop_id = f"{gtfs_code}-{stop_id}"
if StopTime.objects.filter(trip_id=trip_id, stop=stop_id).exists():
st = StopTime.objects.filter(trip_id=trip_id, stop=stop_id)
if st.count() > 1:
st = st.get(stop_sequence=stop_sequence)
else:
st = st.first()
else:
# Stop is added
st = StopTime.objects.create(
id=f"{trip_id}-{stop_time_update.stop_id}",
trip_id=trip_id,
stop_id=stop_id,
defaults={
"stop_sequence": stop_sequence,
"arrival_time": datetime.fromtimestamp(stop_time_update.arrival.time,
tz=ZoneInfo("Europe/Paris")) - start_dt,
"departure_time": datetime.fromtimestamp(stop_time_update.departure.time,
tz=ZoneInfo("Europe/Paris")) - start_dt,
"pickup_type": (PickupType.REGULAR if stop_time_update.departure.time
else PickupType.NONE),
"drop_off_type": (PickupType.REGULAR if stop_time_update.arrival.time
else PickupType.NONE),
}
# Création du TripUpdate
tu, _created = TripUpdate.objects.update_or_create(
trip_id=trip_id,
start_date=trip_update.trip.start_date,
start_time=trip_update.trip.start_time,
defaults=dict(
schedule_relationship=trip_update.trip.schedule_relationship,
)
)
if stop_time_update.schedule_relationship == StopScheduleRelationship.SKIPPED:
if st.pickup_type != PickupType.NONE or st.drop_off_type != PickupType.NONE:
st.pickup_type = PickupType.NONE
st.drop_off_type = PickupType.NONE
for stop_sequence, stop_time_update in enumerate(trip_update.stop_time_update):
stop_id = stop_time_update.stop_id
stop_id = f"{gtfs_code}-{stop_id}"
if StopTime.objects.filter(trip_id=trip_id, stop=stop_id).exists():
st = StopTime.objects.filter(trip_id=trip_id, stop=stop_id)
if st.count() > 1:
st = st.get(stop_sequence=stop_sequence)
else:
st = st.first()
else:
# Stop is added
st = StopTime.objects.create(
id=f"{trip_id}-{stop_time_update.stop_id}",
trip_id=trip_id,
stop_id=stop_id,
stop_sequence=stop_sequence,
arrival_time=datetime.fromtimestamp(stop_time_update.arrival.time,
tz=ZoneInfo("Europe/Paris")) - start_dt,
departure_time=datetime.fromtimestamp(stop_time_update.departure.time,
tz=ZoneInfo("Europe/Paris")) - start_dt,
pickup_type=(PickupType.REGULAR if stop_time_update.departure.time
else PickupType.NONE),
drop_off_type=(PickupType.REGULAR if stop_time_update.arrival.time
else PickupType.NONE),
)
if stop_time_update.schedule_relationship == StopScheduleRelationship.SKIPPED:
if st.pickup_type != PickupType.NONE or st.drop_off_type != PickupType.NONE:
st.pickup_type = PickupType.NONE
st.drop_off_type = PickupType.NONE
st.save()
if st.stop_sequence != stop_sequence:
st.stop_sequence = stop_sequence
st.save()
if st.stop_sequence != stop_sequence:
st.stop_sequence = stop_sequence
st.save()
st_update = StopTimeUpdate(
trip_update=tu,
stop_time=st,
arrival_delay=timedelta(seconds=stop_time_update.arrival.delay),
arrival_time=datetime.fromtimestamp(stop_time_update.arrival.time,
tz=ZoneInfo("Europe/Paris")),
departure_delay=timedelta(seconds=stop_time_update.departure.delay),
departure_time=datetime.fromtimestamp(stop_time_update.departure.time,
tz=ZoneInfo("Europe/Paris")),
schedule_relationship=stop_time_update.schedule_relationship
or StopScheduleRelationship.SCHEDULED,
)
stop_times_updates.append(st_update)
else:
self.stdout.write(str(entity))
st_update = StopTimeUpdate(
trip_update=tu,
stop_time=st,
arrival_delay=timedelta(seconds=stop_time_update.arrival.delay),
arrival_time=datetime.fromtimestamp(stop_time_update.arrival.time,
tz=ZoneInfo("Europe/Paris")),
departure_delay=timedelta(seconds=stop_time_update.departure.delay),
departure_time=datetime.fromtimestamp(stop_time_update.departure.time,
tz=ZoneInfo("Europe/Paris")),
schedule_relationship=stop_time_update.schedule_relationship
or StopScheduleRelationship.SCHEDULED,
)
stop_times_updates.append(st_update)
else:
self.stdout.write(str(entity))
except Exception as e:
self.stderr.write(self.style.ERROR(f"Error while processing entity: {e}"))
StopTimeUpdate.objects.bulk_create(stop_times_updates,
update_conflicts=True,

View File

@ -151,6 +151,8 @@ REST_FRAMEWORK = {
'PAGE_SIZE': 20,
}
STATION_RADIUS = 300
OPENTRANSPORTDATA_SWISS_TOKEN = "CHANGE ME"

View File

@ -18,11 +18,12 @@ from django.contrib import admin
from django.urls import path, include
from rest_framework import routers
from trainvel.api.views import AgencyViewSet, StopViewSet, RouteViewSet, TripViewSet, StopTimeViewSet, \
from trainvel.api.views import AgencyViewSet, StopViewSet, RouteViewSet, StationViewSet, TripViewSet, StopTimeViewSet, \
CalendarViewSet, CalendarDateViewSet, TransferViewSet, FeedInfoViewSet, NextDeparturesViewSet, NextArrivalsViewSet, \
TripUpdateViewSet, StopTimeUpdateViewSet
router = routers.DefaultRouter()
router.register("core/station", StationViewSet)
router.register("gtfs/agency", AgencyViewSet)
router.register("gtfs/stop", StopViewSet)
router.register("gtfs/route", RouteViewSet)