Restructure GTFS feeds into dedicated models

2024-05-09 19:28:19 +02:00
parent 820fc0cc19
commit 11949228ee
11 changed files with 1594 additions and 950 deletions
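
The two management commands below stop iterating over hard-coded GTFS_FEEDS / GTFS_RT_FEEDS dictionaries and instead loop over rows of a new GTFSFeed model, defined in sncfgtfs/models.py (changed in this commit but not shown in this excerpt). A minimal sketch of what that model presumably looks like, inferred only from the attributes the commands access (code, feed_url, rt_feed_url, etag, last_modified); the field types and the name field are assumptions:

from django.db import models

# Hypothetical sketch, not the actual definition from sncfgtfs/models.py.
class GTFSFeed(models.Model):
    code = models.CharField(max_length=64, unique=True)          # e.g. "FR-EUROSTAR", used to namespace ids
    name = models.CharField(max_length=255, blank=True)          # assumed human-readable label for __str__
    feed_url = models.URLField()                                  # static GTFS zip archive
    rt_feed_url = models.URLField(blank=True)                     # optional GTFS-RT endpoint
    etag = models.CharField(max_length=255, blank=True)           # ETag of the last imported archive
    last_modified = models.DateTimeField(null=True, blank=True)   # Last-Modified of the last import

    def __str__(self):
        return self.name or self.code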

View File

@ -2,27 +2,18 @@ import csv
from datetime import datetime, timedelta
from io import BytesIO
from zipfile import ZipFile
from zoneinfo import ZoneInfo
import requests
from django.core.management import BaseCommand
from sncfgtfs.models import Agency, Calendar, CalendarDate, FeedInfo, Route, Stop, StopTime, Transfer, Trip
from sncfgtfs.models import Agency, Calendar, CalendarDate, FeedInfo, GTFSFeed, Route, Stop, StopTime, Transfer, Trip, \
PickupType
class Command(BaseCommand):
help = "Update the SNCF GTFS database."
GTFS_FEEDS = {
"TGV": "https://eu.ftp.opendatasoft.com/sncf/gtfs/export_gtfs_voyages.zip",
"IC": "https://eu.ftp.opendatasoft.com/sncf/gtfs/export-intercites-gtfs-last.zip",
"TER": "https://eu.ftp.opendatasoft.com/sncf/gtfs/export-ter-gtfs-last.zip",
"TN": "https://eu.ftp.opendatasoft.com/sncf/gtfs/transilien-gtfs.zip",
"ES": "https://www.data.gouv.fr/fr/datasets/r/9089b550-696e-4ae0-87b5-40ea55a14292",
"TI": "https://thello.axelor.com/public/gtfs/gtfs.zip",
"RENFE": "https://ssl.renfe.com/gtransit/Fichero_AV_LD/google_transit.zip",
"OBB": "https://static.oebb.at/open-data/soll-fahrplan-gtfs/GTFS_OP_2024_obb.zip",
}
def add_arguments(self, parser):
parser.add_argument('--debug', '-d', action='store_true', help="Activate debug mode")
parser.add_argument('--bulk_size', type=int, default=1000, help="Number of objects to create in bulk.")
@ -30,36 +21,36 @@ class Command(BaseCommand):
help="Do not update the database, only print what would be done.")
parser.add_argument('--force', '-f', action='store_true', help="Force the update of the database.")
def handle(self, *args, **options):
bulk_size = options['bulk_size']
dry_run = options['dry_run']
force = options['force']
def handle(self, debug: bool = False, bulk_size: int = 100, dry_run: bool = False, force: bool = False,
verbosity: int = 1, *args, **options):
if dry_run:
self.stdout.write(self.style.WARNING("Dry run mode activated."))
if not FeedInfo.objects.exists():
last_update_date = "1970-01-01"
else:
last_update_date = FeedInfo.objects.get(publisher_name='SNCF_default').version
for url in self.GTFS_FEEDS.values():
resp = requests.head(url)
if "Last-Modified" not in resp.headers:
continue
last_modified = resp.headers["Last-Modified"]
last_modified = datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z")
if last_modified.date().isoformat() > last_update_date:
break
else:
if not force:
self.stdout.write(self.style.WARNING("Database already up-to-date."))
return
self.stdout.write("Updating database...")
for transport_type, feed_url in self.GTFS_FEEDS.items():
self.stdout.write(f"Downloading {transport_type} GTFS feed...")
with ZipFile(BytesIO(requests.get(feed_url).content)) as zipfile:
for gtfs_feed in GTFSFeed.objects.all():
gtfs_code = gtfs_feed.code
if not force:
# Check if the source file was updated
resp = requests.head(gtfs_feed.feed_url, allow_redirects=True)
if 'ETag' in resp.headers and gtfs_feed.etag:
if resp.headers['ETag'] == gtfs_feed.etag:
if verbosity >= 1:
self.stdout.write(self.style.WARNING(f"Database is already up-to-date for {gtfs_feed}."))
continue
if 'Last-Modified' in resp.headers and gtfs_feed.last_modified:
last_modified = resp.headers['Last-Modified']
last_modified = datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z") \
.replace(tzinfo=ZoneInfo(last_modified.split(' ')[-1]))
if last_modified <= gtfs_feed.last_modified:
if verbosity >= 1:
self.stdout.write(self.style.WARNING(f"Database is already up-to-date for {gtfs_feed}."))
continue
self.stdout.write(f"Downloading GTFS feed for {gtfs_feed}...")
resp = requests.get(gtfs_feed.feed_url, allow_redirects=True, stream=True)
with ZipFile(BytesIO(resp.content)) as zipfile:
def read_file(filename):
lines = zipfile.read(filename).decode().replace('\ufeff', '').splitlines()
return [line.strip() for line in lines]
@ -67,23 +58,25 @@ class Command(BaseCommand):
agencies = []
for agency_dict in csv.DictReader(read_file("agency.txt")):
agency_dict: dict
if transport_type == "ES" \
and agency_dict['agency_id'] != 'ES' and agency_dict['agency_id'] != 'ER':
continue
# if gtfs_code == "FR-EUROSTAR" \
# and agency_dict['agency_id'] != 'ES' and agency_dict['agency_id'] != 'ER':
# continue
agency = Agency(
id=agency_dict['agency_id'],
id=f"{gtfs_code}-{agency_dict['agency_id']}",
name=agency_dict['agency_name'],
url=agency_dict['agency_url'],
timezone=agency_dict['agency_timezone'],
lang=agency_dict.get('agency_lang', "fr"),
phone=agency_dict.get('agency_phone', ""),
email=agency_dict.get('agency_email', ""),
gtfs_feed=gtfs_feed,
)
agencies.append(agency)
if agencies and not dry_run:
Agency.objects.bulk_create(agencies,
update_conflicts=True,
update_fields=['name', 'url', 'timezone', 'lang', 'phone', 'email'],
update_fields=['name', 'url', 'timezone', 'lang', 'phone', 'email',
'gtfs_feed'],
unique_fields=['id'])
agencies.clear()
@ -91,8 +84,10 @@ class Command(BaseCommand):
for stop_dict in csv.DictReader(read_file("stops.txt")):
stop_dict: dict
stop_id = stop_dict['stop_id']
if transport_type in ["ES", "TI", "RENFE"]:
stop_id = f"{transport_type}-{stop_id}"
stop_id = f"{gtfs_code}-{stop_id}"
parent_station_id = stop_dict.get('parent_station', None)
parent_station_id = f"{gtfs_code}-{parent_station_id}" if parent_station_id else None
stop = Stop(
id=stop_id,
@ -102,13 +97,13 @@ class Command(BaseCommand):
lon=stop_dict['stop_lon'],
zone_id=stop_dict.get('zone_id', ""),
url=stop_dict.get('stop_url', ""),
location_type=stop_dict.get('location_type', 1) or 1,
parent_station_id=stop_dict.get('parent_station', None) or None,
location_type=stop_dict.get('location_type', 0) or 0,
parent_station_id=parent_station_id,
timezone=stop_dict.get('stop_timezone', ""),
wheelchair_boarding=stop_dict.get('wheelchair_boarding', 0),
level_id=stop_dict.get('level_id', ""),
platform_code=stop_dict.get('platform_code', ""),
transport_type=transport_type,
gtfs_feed=gtfs_feed,
)
stops.append(stop)
@ -119,7 +114,7 @@ class Command(BaseCommand):
update_fields=['name', 'desc', 'lat', 'lon', 'zone_id', 'url',
'location_type', 'parent_station_id', 'timezone',
'wheelchair_boarding', 'level_id', 'platform_code',
'transport_type'],
'gtfs_feed'],
unique_fields=['id'])
stops.clear()
@ -127,11 +122,10 @@ class Command(BaseCommand):
for route_dict in csv.DictReader(read_file("routes.txt")):
route_dict: dict
route_id = route_dict['route_id']
if transport_type == "TI":
route_id = f"{transport_type}-{route_id}"
route_id = f"{gtfs_code}-{route_id}"
route = Route(
id=route_id,
agency_id=route_dict['agency_id'],
agency_id=f"{gtfs_code}-{route_dict['agency_id']}",
short_name=route_dict['route_short_name'],
long_name=route_dict['route_long_name'],
desc=route_dict.get('route_desc', ""),
@ -139,7 +133,7 @@ class Command(BaseCommand):
url=route_dict.get('route_url', ""),
color=route_dict.get('route_color', ""),
text_color=route_dict.get('route_text_color', ""),
transport_type=transport_type,
gtfs_feed=gtfs_feed,
)
routes.append(route)
@ -148,7 +142,7 @@ class Command(BaseCommand):
update_conflicts=True,
update_fields=['agency_id', 'short_name', 'long_name', 'desc',
'type', 'url', 'color', 'text_color',
'transport_type'],
'gtfs_feed'],
unique_fields=['id'])
routes.clear()
if routes and not dry_run:
@ -156,17 +150,17 @@ class Command(BaseCommand):
update_conflicts=True,
update_fields=['agency_id', 'short_name', 'long_name', 'desc',
'type', 'url', 'color', 'text_color',
'transport_type'],
'gtfs_feed'],
unique_fields=['id'])
routes.clear()
Calendar.objects.filter(transport_type=transport_type).delete()
Calendar.objects.filter(gtfs_feed=gtfs_feed).delete()
calendars = {}
if "calendar.txt" in zipfile.namelist():
for calendar_dict in csv.DictReader(read_file("calendar.txt")):
calendar_dict: dict
calendar = Calendar(
id=f"{transport_type}-{calendar_dict['service_id']}",
id=f"{gtfs_code}-{calendar_dict['service_id']}",
monday=calendar_dict['monday'],
tuesday=calendar_dict['tuesday'],
wednesday=calendar_dict['wednesday'],
@ -176,7 +170,7 @@ class Command(BaseCommand):
sunday=calendar_dict['sunday'],
start_date=calendar_dict['start_date'],
end_date=calendar_dict['end_date'],
transport_type=transport_type,
gtfs_feed=gtfs_feed,
)
calendars[calendar.id] = calendar
@ -185,14 +179,14 @@ class Command(BaseCommand):
update_conflicts=True,
update_fields=['monday', 'tuesday', 'wednesday', 'thursday',
'friday', 'saturday', 'sunday', 'start_date',
'end_date', 'transport_type'],
'end_date', 'gtfs_feed'],
unique_fields=['id'])
calendars.clear()
if calendars and not dry_run:
Calendar.objects.bulk_create(calendars.values(), update_conflicts=True,
update_fields=['monday', 'tuesday', 'wednesday', 'thursday',
'friday', 'saturday', 'sunday', 'start_date',
'end_date', 'transport_type'],
'end_date', 'gtfs_feed'],
unique_fields=['id'])
calendars.clear()
@ -200,8 +194,8 @@ class Command(BaseCommand):
for calendar_date_dict in csv.DictReader(read_file("calendar_dates.txt")):
calendar_date_dict: dict
calendar_date = CalendarDate(
id=f"{transport_type}-{calendar_date_dict['service_id']}-{calendar_date_dict['date']}",
service_id=f"{transport_type}-{calendar_date_dict['service_id']}",
id=f"{gtfs_code}-{calendar_date_dict['service_id']}-{calendar_date_dict['date']}",
service_id=f"{gtfs_code}-{calendar_date_dict['service_id']}",
date=calendar_date_dict['date'],
exception_type=calendar_date_dict['exception_type'],
)
@ -209,7 +203,7 @@ class Command(BaseCommand):
if calendar_date.service_id not in calendars:
calendar = Calendar(
id=f"{transport_type}-{calendar_date_dict['service_id']}",
id=f"{gtfs_code}-{calendar_date_dict['service_id']}",
monday=False,
tuesday=False,
wednesday=False,
@ -219,11 +213,11 @@ class Command(BaseCommand):
sunday=False,
start_date=calendar_date_dict['date'],
end_date=calendar_date_dict['date'],
transport_type=transport_type,
gtfs_feed=gtfs_feed,
)
calendars[calendar.id] = calendar
else:
calendar = calendars[f"{transport_type}-{calendar_date_dict['service_id']}"]
calendar = calendars[f"{gtfs_code}-{calendar_date_dict['service_id']}"]
if calendar.start_date > calendar_date.date:
calendar.start_date = calendar_date.date
if calendar.end_date < calendar_date.date:
@ -233,7 +227,7 @@ class Command(BaseCommand):
Calendar.objects.bulk_create(calendars.values(),
batch_size=bulk_size,
update_conflicts=True,
update_fields=['start_date', 'end_date'],
update_fields=['start_date', 'end_date', 'gtfs_feed'],
unique_fields=['id'])
CalendarDate.objects.bulk_create(calendar_dates,
batch_size=bulk_size,
@ -248,22 +242,12 @@ class Command(BaseCommand):
trip_dict: dict
trip_id = trip_dict['trip_id']
route_id = trip_dict['route_id']
if transport_type in ["TGV", "IC", "TER"]:
trip_id, last_update = trip_id.split(':', 1)
last_update = datetime.fromisoformat(last_update)
elif transport_type in ["ES", "RENFE"]:
trip_id = f"{transport_type}-{trip_id}"
last_update = None
elif transport_type == "TI":
trip_id = f"{transport_type}-{trip_id}"
route_id = f"{transport_type}-{route_id}"
last_update = None
else:
last_update = None
trip_id = f"{gtfs_code}-{trip_id}"
route_id = f"{gtfs_code}-{route_id}"
trip = Trip(
id=trip_id,
route_id=route_id,
service_id=f"{transport_type}-{trip_dict['service_id']}",
service_id=f"{gtfs_code}-{trip_dict['service_id']}",
headsign=trip_dict.get('trip_headsign', ""),
short_name=trip_dict.get('trip_short_name', ""),
direction_id=trip_dict.get('direction_id', None) or None,
@ -271,7 +255,7 @@ class Command(BaseCommand):
shape_id=trip_dict.get('shape_id', ""),
wheelchair_accessible=trip_dict.get('wheelchair_accessible', None),
bikes_allowed=trip_dict.get('bikes_allowed', None),
last_update=last_update,
gtfs_feed=gtfs_feed,
)
trips.append(trip)
@ -280,7 +264,7 @@ class Command(BaseCommand):
update_conflicts=True,
update_fields=['route_id', 'service_id', 'headsign', 'short_name',
'direction_id', 'block_id', 'shape_id',
'wheelchair_accessible', 'bikes_allowed'],
'wheelchair_accessible', 'bikes_allowed', 'gtfs_feed'],
unique_fields=['id'])
trips.clear()
if trips and not dry_run:
@ -288,7 +272,7 @@ class Command(BaseCommand):
update_conflicts=True,
update_fields=['route_id', 'service_id', 'headsign', 'short_name',
'direction_id', 'block_id', 'shape_id',
'wheelchair_accessible', 'bikes_allowed'],
'wheelchair_accessible', 'bikes_allowed', 'gtfs_feed'],
unique_fields=['id'])
trips.clear()
@ -297,14 +281,10 @@ class Command(BaseCommand):
stop_time_dict: dict
stop_id = stop_time_dict['stop_id']
if transport_type in ["ES", "TI", "RENFE"]:
stop_id = f"{transport_type}-{stop_id}"
stop_id = f"{gtfs_code}-{stop_id}"
trip_id = stop_time_dict['trip_id']
if transport_type in ["TGV", "IC", "TER"]:
trip_id = trip_id.split(':', 1)[0]
elif transport_type in ["ES", "TI", "RENFE"]:
trip_id = f"{transport_type}-{trip_id}"
trip_id = f"{gtfs_code}-{trip_id}"
arr_time = stop_time_dict['arrival_time']
arr_h, arr_m, arr_s = map(int, arr_time.split(':'))
@ -315,19 +295,16 @@ class Command(BaseCommand):
pickup_type = stop_time_dict.get('pickup_type', 0)
drop_off_type = stop_time_dict.get('drop_off_type', 0)
if transport_type in ["ES", "RENFE", "OBB"]:
if stop_time_dict['stop_sequence'] == "1":
drop_off_type = 1
elif arr_time == dep_time:
pickup_type = 1
elif transport_type == "TI":
if stop_time_dict['stop_sequence'] == "0":
drop_off_type = 1
elif arr_time == dep_time:
pickup_type = 1
if stop_time_dict['stop_sequence'] == "1":
# First stop
drop_off_type = PickupType.NONE
elif arr_time == dep_time:
# Last stop
pickup_type = PickupType.NONE
st = StopTime(
id=f"{trip_id}-{stop_id}-{stop_time_dict['departure_time']}",
id=f"{gtfs_code}-{stop_time_dict['trip_id']}-{stop_time_dict['stop_id']}"
f"-{stop_time_dict['departure_time']}",
trip_id=trip_id,
arrival_time=timedelta(seconds=arr_time),
departure_time=timedelta(seconds=dep_time),
@ -363,14 +340,13 @@ class Command(BaseCommand):
transfer_dict: dict
from_stop_id = transfer_dict['from_stop_id']
to_stop_id = transfer_dict['to_stop_id']
if transport_type in ["ES", "RENFE", "OBB"]:
from_stop_id = f"{transport_type}-{from_stop_id}"
to_stop_id = f"{transport_type}-{to_stop_id}"
from_stop_id = f"{gtfs_code}-{from_stop_id}"
to_stop_id = f"{gtfs_code}-{to_stop_id}"
transfer = Transfer(
id=f"{from_stop_id}-{to_stop_id}",
from_stop_id=transfer_dict['from_stop_id'],
to_stop_id=transfer_dict['to_stop_id'],
id=f"{transfer_dict['from_stop_id']}-{transfer_dict['to_stop_id']}",
from_stop_id=from_stop_id,
to_stop_id=to_stop_id,
transfer_type=transfer_dict['transfer_type'],
min_transfer_time=transfer_dict['min_transfer_time'],
)
@ -395,6 +371,7 @@ class Command(BaseCommand):
feed_info_dict: dict
FeedInfo.objects.update_or_create(
publisher_name=feed_info_dict['feed_publisher_name'],
gtfs_feed=gtfs_feed,
defaults=dict(
publisher_url=feed_info_dict['feed_publisher_url'],
lang=feed_info_dict['feed_lang'],
@ -403,3 +380,12 @@ class Command(BaseCommand):
version=feed_info_dict.get('feed_version', 1),
)
)
if 'ETag' in resp.headers:
gtfs_feed.etag = resp.headers['ETag']
gtfs_feed.save()
if 'Last-Modified' in resp.headers:
last_modified = resp.headers['Last-Modified']
gtfs_feed.last_modified = datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z") \
.replace(tzinfo=ZoneInfo(last_modified.split(' ')[-1]))
gtfs_feed.save()
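
With the GTFSFeed model in place, the freshness check at the top of handle() relies on the HTTP ETag and Last-Modified validators stored on each feed instead of the FeedInfo version string used before. A condensed sketch of that check as a standalone function, under the assumption of the model sketched above; the helper name feed_needs_update is not part of the commit:

from datetime import datetime
from zoneinfo import ZoneInfo

import requests

def feed_needs_update(gtfs_feed) -> bool:
    # Ask the server for headers only; follow redirects like the command does.
    resp = requests.head(gtfs_feed.feed_url, allow_redirects=True)
    # An unchanged ETag means the archive was not re-published.
    if 'ETag' in resp.headers and gtfs_feed.etag:
        return resp.headers['ETag'] != gtfs_feed.etag
    # Otherwise compare Last-Modified, parsed as an RFC 1123 date and made
    # timezone-aware the same way the command does it.
    if 'Last-Modified' in resp.headers and gtfs_feed.last_modified:
        raw = resp.headers['Last-Modified']
        remote = datetime.strptime(raw, "%a, %d %b %Y %H:%M:%S %Z") \
            .replace(tzinfo=ZoneInfo(raw.split(' ')[-1]))
        return remote > gtfs_feed.last_modified
    # No usable validator: assume the feed needs to be refreshed.
    return True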

View File

@ -5,8 +5,8 @@ import requests
from django.core.management import BaseCommand
from django.db.models import Q
from sncfgtfs.gtfs_realtime_pb2 import FeedMessage
from sncfgtfs.models import Agency, Calendar, CalendarDate, ExceptionType, LocationType, PickupType, \
from sncfgtfs.gtfs_realtime_pb2 import FeedMessage, TripUpdate as GTFSTripUpdate
from sncfgtfs.models import Agency, Calendar, CalendarDate, ExceptionType, GTFSFeed, LocationType, PickupType, \
Route, RouteType, Stop, StopScheduleRelationship, StopTime, StopTimeUpdate, \
Trip, TripUpdate, TripScheduleRelationship
@ -14,34 +14,33 @@ from sncfgtfs.models import Agency, Calendar, CalendarDate, ExceptionType, Locat
class Command(BaseCommand):
help = "Update the SNCF GTFS Realtime database."
GTFS_RT_FEEDS = {
"TGV": "https://proxy.transport.data.gouv.fr/resource/sncf-tgv-gtfs-rt-trip-updates",
"IC": "https://proxy.transport.data.gouv.fr/resource/sncf-ic-gtfs-rt-trip-updates",
"TER": "https://proxy.transport.data.gouv.fr/resource/sncf-ter-gtfs-rt-trip-updates",
"TI": "https://thello.axelor.com/public/gtfs/GTFS-RT.bin",
}
def add_arguments(self, parser):
parser.add_argument('--debug', '-d', action='store_true', help="Activate debug mode")
def handle(self, debug=False, *args, **options):
for feed_type, feed_url in self.GTFS_RT_FEEDS.items():
self.stdout.write(f"Updating {feed_type} feed...")
def handle(self, debug: bool = False, verbosity: int = 1, *args, **options):
for gtfs_feed in GTFSFeed.objects.all():
if not gtfs_feed.rt_feed_url:
if verbosity >= 2:
self.stdout.write(self.style.WARNING(f"No GTFS-RT feed found for {gtfs_feed}."))
continue
self.stdout.write(f"Updating GTFS-RT feed for {gtfs_feed}")
gtfs_code = gtfs_feed.code
feed_message = FeedMessage()
feed_message.ParseFromString(requests.get(feed_url).content)
feed_message.ParseFromString(requests.get(gtfs_feed.rt_feed_url, allow_redirects=True).content)
stop_times_updates = []
if debug:
with open(f'feed_message-{feed_type}.txt', 'w') as f:
with open(f'feed_message-{gtfs_code}.txt', 'w') as f:
f.write(str(feed_message))
for entity in feed_message.entity:
if entity.HasField("trip_update"):
trip_update = entity.trip_update
trip_id = trip_update.trip.trip_id
if feed_type in ["TGV", "IC", "TER"]:
trip_id = trip_id.split(":", 1)[0]
trip_id = f"{gtfs_code}-{trip_id}"
start_date = date(year=int(trip_update.trip.start_date[:4]),
month=int(trip_update.trip.start_date[4:6]),
@ -50,7 +49,7 @@ class Command(BaseCommand):
if trip_update.trip.schedule_relationship == TripScheduleRelationship.ADDED:
# This is a new trip. Create the corresponding Trip object.
self.create_trip(trip_update, trip_id, start_dt, feed_type)
self.create_trip(trip_update, trip_id, start_dt, gtfs_feed)
if not Trip.objects.filter(id=trip_id).exists():
self.stdout.write(f"Trip {trip_id} does not exist in the GTFS feed.")
@ -68,22 +67,19 @@ class Command(BaseCommand):
for stop_sequence, stop_time_update in enumerate(trip_update.stop_time_update):
stop_id = stop_time_update.stop_id
if stop_id.startswith('StopArea:'):
# This is a station (StopArea). Look up the associated platform.
if StopTime.objects.filter(trip_id=trip_id, stop__parent_station_id=stop_id).exists():
# U
stop = StopTime.objects.get(trip_id=trip_id, stop__parent_station_id=stop_id).stop
stop_id = f"{gtfs_code}-{stop_id}"
if StopTime.objects.filter(trip_id=trip_id, stop=stop_id).exists():
st = StopTime.objects.filter(trip_id=trip_id, stop=stop_id)
if st.count() > 1:
st = st.get(stop_sequence=stop_sequence)
else:
stops = [s for s in Stop.objects.filter(parent_station_id=stop_id).all()
for s2 in StopTime.objects.filter(trip_id=trip_id).all()
if s.stop_type in s2.stop.stop_type
or s2.stop.stop_type in s.stop_type]
stop = stops[0] if stops else Stop.objects.get(id=stop_id)
st, _created = StopTime.objects.update_or_create(
id=f"{trip_id}-{stop.id}",
st = st.first()
else:
# Stop is added
st = StopTime.objects.create(
id=f"{trip_id}-{stop_time_update.stop_id}",
trip_id=trip_id,
stop_id=stop.id,
stop_id=stop_id,
defaults={
"stop_sequence": stop_sequence,
"arrival_time": datetime.fromtimestamp(stop_time_update.arrival.time,
@ -96,23 +92,16 @@ class Command(BaseCommand):
else PickupType.NONE),
}
)
elif stop_time_update.schedule_relationship == StopScheduleRelationship.SKIPPED:
st = StopTime.objects.get(Q(stop=stop_id) | Q(stop__parent_station_id=stop_id),
trip_id=trip_id)
if stop_time_update.schedule_relationship == StopScheduleRelationship.SKIPPED:
if st.pickup_type != PickupType.NONE or st.drop_off_type != PickupType.NONE:
st.pickup_type = PickupType.NONE
st.drop_off_type = PickupType.NONE
st.save()
else:
qs = StopTime.objects.filter(Q(stop=stop_id) | Q(stop__parent_station_id=stop_id),
trip_id=trip_id)
if qs.count() == 1:
st = qs.first()
else:
st = qs.get(stop_sequence=stop_sequence)
if st.stop_sequence != stop_sequence:
st.stop_sequence = stop_sequence
st.save()
if st.stop_sequence != stop_sequence:
st.stop_sequence = stop_sequence
st.save()
st_update = StopTimeUpdate(
trip_update=tu,
@ -136,73 +125,22 @@ class Command(BaseCommand):
'departure_delay', 'departure_time'],
unique_fields=['trip_update', 'stop_time'])
def create_trip(self, trip_update, trip_id, start_dt, feed_type):
def create_trip(self, trip_update: GTFSTripUpdate, trip_id: str, start_dt: datetime, gtfs_feed: GTFSFeed) -> None:
headsign = trip_id[5:-1]
trip_qs = Trip.objects.all()
trip_ids = trip_qs.values_list('id', flat=True)
gtfs_code = gtfs_feed.code
first_stop_queryset = StopTime.objects.filter(
stop__parent_station_id=trip_update.stop_time_update[0].stop_id,
).values('trip_id')
last_stop_queryset = StopTime.objects.filter(
stop__parent_station_id=trip_update.stop_time_update[-1].stop_id,
).values('trip_id')
trip_ids = trip_ids.intersection(first_stop_queryset).intersection(last_stop_queryset)
# print(trip_id, trip_ids)
for stop_sequence, stop_time_update in enumerate(trip_update.stop_time_update):
stop_id = stop_time_update.stop_id
st_queryset = StopTime.objects.filter(stop__parent_station_id=stop_id)
if stop_sequence == 0:
st_queryset = st_queryset.filter(stop_sequence=0)
# print(stop_sequence, Stop.objects.get(id=stop_id).name, stop_time_update)
# print(trip_ids)
# print(st_queryset.values('trip_id').all())
trip_ids_restrict = trip_ids.intersection(st_queryset.values('trip_id'))
if trip_ids_restrict:
trip_ids = trip_ids_restrict
else:
stop = Stop.objects.get(id=stop_id)
self.stdout.write(self.style.WARNING(f"Warning: No trip is found passing by stop "
f"{stop.name} ({stop_id})"))
trip_ids = set(trip_ids)
route_ids = set(Trip.objects.filter(id__in=trip_ids).values_list('route_id', flat=True))
self.stdout.write(f"{len(route_ids)} routes found on trip for new train {headsign}")
if not route_ids:
origin_id = trip_update.stop_time_update[0].stop_id
origin = Stop.objects.get(id=origin_id)
destination_id = trip_update.stop_time_update[-1].stop_id
destination = Stop.objects.get(id=destination_id)
trip_name = f"{origin.name} - {destination.name}"
trip_reverse_name = f"{destination.name} - {origin.name}"
route_qs = Route.objects.filter(long_name=trip_name, transport_type=feed_type)
route_reverse_qs = Route.objects.filter(long_name=trip_reverse_name,
transport_type=feed_type)
if route_qs.exists():
route_ids = set(route_qs.values_list('id', flat=True))
elif route_reverse_qs.exists():
route_ids = set(route_reverse_qs.values_list('id', flat=True))
else:
self.stdout.write(f"Route not found for trip {trip_id} ({trip_name}). Creating new one")
route = Route.objects.create(
id=f"CREATED-{trip_name}",
agency=Agency.objects.filter(routes__transport_type=feed_type).first(),
transport_type=feed_type,
type=RouteType.RAIL,
short_name=trip_name,
long_name=trip_name,
)
route_ids = {route.id}
self.stdout.write(f"Route {route.id} created for trip {trip_id} ({trip_name})")
elif len(route_ids) > 1:
self.stdout.write(f"Multiple routes found for trip {trip_id}.")
self.stdout.write(", ".join(route_ids))
route_id = route_ids.pop()
route, _created = Route.objects.get_or_create(
id=f"{gtfs_code}-ADDED-{headsign}",
gtfs_feed=gtfs_feed,
type=RouteType.RAIL,
short_name="ADDED",
long_name="ADDED ROUTE",
)
Calendar.objects.update_or_create(
id=f"{feed_type}-new-{headsign}",
id=f"{gtfs_code}-ADDED-{headsign}",
defaults={
"transport_type": feed_type,
"gtfs_feed": gtfs_feed,
"monday": False,
"tuesday": False,
"wednesday": False,
@ -215,9 +153,9 @@ class Command(BaseCommand):
}
)
CalendarDate.objects.update_or_create(
id=f"{feed_type}-{headsign}-{trip_update.trip.start_date}",
id=f"{gtfs_code}-ADDED-{headsign}-{trip_update.trip.start_date}",
defaults={
"service_id": f"{feed_type}-new-{headsign}",
"service_id": f"{gtfs_code}-ADDED-{headsign}",
"date": trip_update.trip.start_date,
"exception_type": ExceptionType.ADDED,
}
@ -225,32 +163,17 @@ class Command(BaseCommand):
Trip.objects.update_or_create(
id=trip_id,
defaults={
"route_id": route_id,
"service_id": f"{feed_type}-new-{headsign}",
"route_id": route.id,
"service_id": f"{gtfs_code}-ADDED-{headsign}",
"headsign": headsign,
"direction_id": trip_update.trip.direction_id,
"gtfs_feed": gtfs_feed,
}
)
sample_trip = Trip.objects.filter(id__in=trip_ids, route_id=route_id)
sample_trip = sample_trip.first() if sample_trip.exists() else None
for stop_sequence, stop_time_update in enumerate(trip_update.stop_time_update):
stop_id = stop_time_update.stop_id
stop = Stop.objects.get(id=stop_id)
if stop.location_type == LocationType.STATION:
if not StopTime.objects.filter(trip_id=trip_id).exists():
if sample_trip:
stop = StopTime.objects.get(trip_id=sample_trip.id,
stop__parent_station_id=stop_id).stop
elif StopTime.objects.filter(trip_id=trip_id, stop__parent_station_id=stop_id).exists():
stop = StopTime.objects.get(trip_id=trip_id, stop__parent_station_id=stop_id).stop
else:
stops = [s for s in Stop.objects.filter(parent_station_id=stop_id).all()
for s2 in StopTime.objects.filter(trip_id=trip_id).all()
if s.stop_type in s2.stop.stop_type
or s2.stop.stop_type in s.stop_type]
stop = stops[0] if stops else stop
stop_id = stop.id
stop_id = f"{gtfs_code}-{stop_id}"
arr_time = datetime.fromtimestamp(stop_time_update.arrival.time,
tz=ZoneInfo("Europe/Paris")) - start_dt
@ -263,7 +186,7 @@ class Command(BaseCommand):
and stop_sequence < len(trip_update.stop_time_update) - 1 else PickupType.NONE
StopTime.objects.update_or_create(
id=f"{trip_id}-{stop_id}",
id=f"{trip_id}-{stop_time_update.stop_id}",
trip_id=trip_id,
defaults={
"stop_id": stop_id,