Restructurate GTFS feeds into dedicated models

This commit is contained in:
2024-05-09 19:28:19 +02:00
parent 820fc0cc19
commit 11949228ee
11 changed files with 1594 additions and 950 deletions

View File

@ -5,8 +5,8 @@ import requests
from django.core.management import BaseCommand
from django.db.models import Q
from sncfgtfs.gtfs_realtime_pb2 import FeedMessage
from sncfgtfs.models import Agency, Calendar, CalendarDate, ExceptionType, LocationType, PickupType, \
from sncfgtfs.gtfs_realtime_pb2 import FeedMessage, TripUpdate as GTFSTripUpdate
from sncfgtfs.models import Agency, Calendar, CalendarDate, ExceptionType, GTFSFeed, LocationType, PickupType, \
Route, RouteType, Stop, StopScheduleRelationship, StopTime, StopTimeUpdate, \
Trip, TripUpdate, TripScheduleRelationship
@ -14,34 +14,33 @@ from sncfgtfs.models import Agency, Calendar, CalendarDate, ExceptionType, Locat
class Command(BaseCommand):
help = "Update the SNCF GTFS Realtime database."
GTFS_RT_FEEDS = {
"TGV": "https://proxy.transport.data.gouv.fr/resource/sncf-tgv-gtfs-rt-trip-updates",
"IC": "https://proxy.transport.data.gouv.fr/resource/sncf-ic-gtfs-rt-trip-updates",
"TER": "https://proxy.transport.data.gouv.fr/resource/sncf-ter-gtfs-rt-trip-updates",
"TI": "https://thello.axelor.com/public/gtfs/GTFS-RT.bin",
}
def add_arguments(self, parser):
parser.add_argument('--debug', '-d', action='store_true', help="Activate debug mode")
def handle(self, debug=False, *args, **options):
for feed_type, feed_url in self.GTFS_RT_FEEDS.items():
self.stdout.write(f"Updating {feed_type} feed...")
def handle(self, debug: bool = False, verbosity: int = 1, *args, **options):
for gtfs_feed in GTFSFeed.objects.all():
if not gtfs_feed.rt_feed_url:
if verbosity >= 2:
self.stdout.write(self.style.WARNING(f"No GTFS-RT feed found for {gtfs_feed}."))
continue
self.stdout.write(f"Updating GTFS-RT feed for {gtfs_feed}")
gtfs_code = gtfs_feed.code
feed_message = FeedMessage()
feed_message.ParseFromString(requests.get(feed_url).content)
feed_message.ParseFromString(requests.get(gtfs_feed.rt_feed_url, allow_redirects=True).content)
stop_times_updates = []
if debug:
with open(f'feed_message-{feed_type}.txt', 'w') as f:
with open(f'feed_message-{gtfs_code}.txt', 'w') as f:
f.write(str(feed_message))
for entity in feed_message.entity:
if entity.HasField("trip_update"):
trip_update = entity.trip_update
trip_id = trip_update.trip.trip_id
if feed_type in ["TGV", "IC", "TER"]:
trip_id = trip_id.split(":", 1)[0]
trip_id = f"{gtfs_code}-{trip_id}"
start_date = date(year=int(trip_update.trip.start_date[:4]),
month=int(trip_update.trip.start_date[4:6]),
@ -50,7 +49,7 @@ class Command(BaseCommand):
if trip_update.trip.schedule_relationship == TripScheduleRelationship.ADDED:
# C'est un trajet nouveau. On crée le trajet associé.
self.create_trip(trip_update, trip_id, start_dt, feed_type)
self.create_trip(trip_update, trip_id, start_dt, gtfs_feed)
if not Trip.objects.filter(id=trip_id).exists():
self.stdout.write(f"Trip {trip_id} does not exist in the GTFS feed.")
@ -68,22 +67,19 @@ class Command(BaseCommand):
for stop_sequence, stop_time_update in enumerate(trip_update.stop_time_update):
stop_id = stop_time_update.stop_id
if stop_id.startswith('StopArea:'):
# On est dans le cadre d'une gare. On cherche le quai associé.
if StopTime.objects.filter(trip_id=trip_id, stop__parent_station_id=stop_id).exists():
# U
stop = StopTime.objects.get(trip_id=trip_id, stop__parent_station_id=stop_id).stop
stop_id = f"{gtfs_code}-{stop_id}"
if StopTime.objects.filter(trip_id=trip_id, stop=stop_id).exists():
st = StopTime.objects.filter(trip_id=trip_id, stop=stop_id)
if st.count() > 1:
st = st.get(stop_sequence=stop_sequence)
else:
stops = [s for s in Stop.objects.filter(parent_station_id=stop_id).all()
for s2 in StopTime.objects.filter(trip_id=trip_id).all()
if s.stop_type in s2.stop.stop_type
or s2.stop.stop_type in s.stop_type]
stop = stops[0] if stops else Stop.objects.get(id=stop_id)
st, _created = StopTime.objects.update_or_create(
id=f"{trip_id}-{stop.id}",
st = st.first()
else:
# Stop is added
st = StopTime.objects.create(
id=f"{trip_id}-{stop_time_update.stop_id}",
trip_id=trip_id,
stop_id=stop.id,
stop_id=stop_id,
defaults={
"stop_sequence": stop_sequence,
"arrival_time": datetime.fromtimestamp(stop_time_update.arrival.time,
@ -96,23 +92,16 @@ class Command(BaseCommand):
else PickupType.NONE),
}
)
elif stop_time_update.schedule_relationship == StopScheduleRelationship.SKIPPED:
st = StopTime.objects.get(Q(stop=stop_id) | Q(stop__parent_station_id=stop_id),
trip_id=trip_id)
if stop_time_update.schedule_relationship == StopScheduleRelationship.SKIPPED:
if st.pickup_type != PickupType.NONE or st.drop_off_type != PickupType.NONE:
st.pickup_type = PickupType.NONE
st.drop_off_type = PickupType.NONE
st.save()
else:
qs = StopTime.objects.filter(Q(stop=stop_id) | Q(stop__parent_station_id=stop_id),
trip_id=trip_id)
if qs.count() == 1:
st = qs.first()
else:
st = qs.get(stop_sequence=stop_sequence)
if st.stop_sequence != stop_sequence:
st.stop_sequence = stop_sequence
st.save()
if st.stop_sequence != stop_sequence:
st.stop_sequence = stop_sequence
st.save()
st_update = StopTimeUpdate(
trip_update=tu,
@ -136,73 +125,22 @@ class Command(BaseCommand):
'departure_delay', 'departure_time'],
unique_fields=['trip_update', 'stop_time'])
def create_trip(self, trip_update, trip_id, start_dt, feed_type):
def create_trip(self, trip_update: GTFSTripUpdate, trip_id: str, start_dt: datetime, gtfs_feed: GTFSFeed) -> None:
headsign = trip_id[5:-1]
trip_qs = Trip.objects.all()
trip_ids = trip_qs.values_list('id', flat=True)
gtfs_code = gtfs_feed.code
first_stop_queryset = StopTime.objects.filter(
stop__parent_station_id=trip_update.stop_time_update[0].stop_id,
).values('trip_id')
last_stop_queryset = StopTime.objects.filter(
stop__parent_station_id=trip_update.stop_time_update[-1].stop_id,
).values('trip_id')
trip_ids = trip_ids.intersection(first_stop_queryset).intersection(last_stop_queryset)
# print(trip_id, trip_ids)
for stop_sequence, stop_time_update in enumerate(trip_update.stop_time_update):
stop_id = stop_time_update.stop_id
st_queryset = StopTime.objects.filter(stop__parent_station_id=stop_id)
if stop_sequence == 0:
st_queryset = st_queryset.filter(stop_sequence=0)
# print(stop_sequence, Stop.objects.get(id=stop_id).name, stop_time_update)
# print(trip_ids)
# print(st_queryset.values('trip_id').all())
trip_ids_restrict = trip_ids.intersection(st_queryset.values('trip_id'))
if trip_ids_restrict:
trip_ids = trip_ids_restrict
else:
stop = Stop.objects.get(id=stop_id)
self.stdout.write(self.style.WARNING(f"Warning: No trip is found passing by stop "
f"{stop.name} ({stop_id})"))
trip_ids = set(trip_ids)
route_ids = set(Trip.objects.filter(id__in=trip_ids).values_list('route_id', flat=True))
self.stdout.write(f"{len(route_ids)} routes found on trip for new train {headsign}")
if not route_ids:
origin_id = trip_update.stop_time_update[0].stop_id
origin = Stop.objects.get(id=origin_id)
destination_id = trip_update.stop_time_update[-1].stop_id
destination = Stop.objects.get(id=destination_id)
trip_name = f"{origin.name} - {destination.name}"
trip_reverse_name = f"{destination.name} - {origin.name}"
route_qs = Route.objects.filter(long_name=trip_name, transport_type=feed_type)
route_reverse_qs = Route.objects.filter(long_name=trip_reverse_name,
transport_type=feed_type)
if route_qs.exists():
route_ids = set(route_qs.values_list('id', flat=True))
elif route_reverse_qs.exists():
route_ids = set(route_reverse_qs.values_list('id', flat=True))
else:
self.stdout.write(f"Route not found for trip {trip_id} ({trip_name}). Creating new one")
route = Route.objects.create(
id=f"CREATED-{trip_name}",
agency=Agency.objects.filter(routes__transport_type=feed_type).first(),
transport_type=feed_type,
type=RouteType.RAIL,
short_name=trip_name,
long_name=trip_name,
)
route_ids = {route.id}
self.stdout.write(f"Route {route.id} created for trip {trip_id} ({trip_name})")
elif len(route_ids) > 1:
self.stdout.write(f"Multiple routes found for trip {trip_id}.")
self.stdout.write(", ".join(route_ids))
route_id = route_ids.pop()
route, _created = Route.objects.get_or_create(
id=f"{gtfs_code}-ADDED-{headsign}",
gtfs_feed=gtfs_feed,
type=RouteType.RAIL,
short_name="ADDED",
long_name="ADDED ROUTE",
)
Calendar.objects.update_or_create(
id=f"{feed_type}-new-{headsign}",
id=f"{gtfs_code}-ADDED-{headsign}",
defaults={
"transport_type": feed_type,
"gtfs_feed": gtfs_feed,
"monday": False,
"tuesday": False,
"wednesday": False,
@ -215,9 +153,9 @@ class Command(BaseCommand):
}
)
CalendarDate.objects.update_or_create(
id=f"{feed_type}-{headsign}-{trip_update.trip.start_date}",
id=f"{gtfs_code}-ADDED-{headsign}-{trip_update.trip.start_date}",
defaults={
"service_id": f"{feed_type}-new-{headsign}",
"service_id": f"{gtfs_code}-ADDED-{headsign}",
"date": trip_update.trip.start_date,
"exception_type": ExceptionType.ADDED,
}
@ -225,32 +163,17 @@ class Command(BaseCommand):
Trip.objects.update_or_create(
id=trip_id,
defaults={
"route_id": route_id,
"service_id": f"{feed_type}-new-{headsign}",
"route_id": route.id,
"service_id": f"{gtfs_code}-ADDED-{headsign}",
"headsign": headsign,
"direction_id": trip_update.trip.direction_id,
"gtfs_feed": gtfs_feed,
}
)
sample_trip = Trip.objects.filter(id__in=trip_ids, route_id=route_id)
sample_trip = sample_trip.first() if sample_trip.exists() else None
for stop_sequence, stop_time_update in enumerate(trip_update.stop_time_update):
stop_id = stop_time_update.stop_id
stop = Stop.objects.get(id=stop_id)
if stop.location_type == LocationType.STATION:
if not StopTime.objects.filter(trip_id=trip_id).exists():
if sample_trip:
stop = StopTime.objects.get(trip_id=sample_trip.id,
stop__parent_station_id=stop_id).stop
elif StopTime.objects.filter(trip_id=trip_id, stop__parent_station_id=stop_id).exists():
stop = StopTime.objects.get(trip_id=trip_id, stop__parent_station_id=stop_id).stop
else:
stops = [s for s in Stop.objects.filter(parent_station_id=stop_id).all()
for s2 in StopTime.objects.filter(trip_id=trip_id).all()
if s.stop_type in s2.stop.stop_type
or s2.stop.stop_type in s.stop_type]
stop = stops[0] if stops else stop
stop_id = stop.id
stop_id = f"{gtfs_code}-{stop_id}"
arr_time = datetime.fromtimestamp(stop_time_update.arrival.time,
tz=ZoneInfo("Europe/Paris")) - start_dt
@ -263,7 +186,7 @@ class Command(BaseCommand):
and stop_sequence < len(trip_update.stop_time_update) - 1 else PickupType.NONE
StopTime.objects.update_or_create(
id=f"{trip_id}-{stop_id}",
id=f"{trip_id}-{stop_time_update.stop_id}",
trip_id=trip_id,
defaults={
"stop_id": stop_id,