Restructurate GTFS feeds into dedicated models

This commit is contained in:
2024-05-09 19:28:19 +02:00
parent 820fc0cc19
commit 11949228ee
11 changed files with 1594 additions and 950 deletions

View File

@ -2,27 +2,18 @@ import csv
from datetime import datetime, timedelta
from io import BytesIO
from zipfile import ZipFile
from zoneinfo import ZoneInfo
import requests
from django.core.management import BaseCommand
from sncfgtfs.models import Agency, Calendar, CalendarDate, FeedInfo, Route, Stop, StopTime, Transfer, Trip
from sncfgtfs.models import Agency, Calendar, CalendarDate, FeedInfo, GTFSFeed, Route, Stop, StopTime, Transfer, Trip, \
PickupType
class Command(BaseCommand):
help = "Update the SNCF GTFS database."
GTFS_FEEDS = {
"TGV": "https://eu.ftp.opendatasoft.com/sncf/gtfs/export_gtfs_voyages.zip",
"IC": "https://eu.ftp.opendatasoft.com/sncf/gtfs/export-intercites-gtfs-last.zip",
"TER": "https://eu.ftp.opendatasoft.com/sncf/gtfs/export-ter-gtfs-last.zip",
"TN": "https://eu.ftp.opendatasoft.com/sncf/gtfs/transilien-gtfs.zip",
"ES": "https://www.data.gouv.fr/fr/datasets/r/9089b550-696e-4ae0-87b5-40ea55a14292",
"TI": "https://thello.axelor.com/public/gtfs/gtfs.zip",
"RENFE": "https://ssl.renfe.com/gtransit/Fichero_AV_LD/google_transit.zip",
"OBB": "https://static.oebb.at/open-data/soll-fahrplan-gtfs/GTFS_OP_2024_obb.zip",
}
def add_arguments(self, parser):
parser.add_argument('--debug', '-d', action='store_true', help="Activate debug mode")
parser.add_argument('--bulk_size', type=int, default=1000, help="Number of objects to create in bulk.")
@ -30,36 +21,36 @@ class Command(BaseCommand):
help="Do not update the database, only print what would be done.")
parser.add_argument('--force', '-f', action='store_true', help="Force the update of the database.")
def handle(self, *args, **options):
bulk_size = options['bulk_size']
dry_run = options['dry_run']
force = options['force']
def handle(self, debug: bool = False, bulk_size: int = 100, dry_run: bool = False, force: bool = False,
verbosity: int = 1, *args, **options):
if dry_run:
self.stdout.write(self.style.WARNING("Dry run mode activated."))
if not FeedInfo.objects.exists():
last_update_date = "1970-01-01"
else:
last_update_date = FeedInfo.objects.get(publisher_name='SNCF_default').version
for url in self.GTFS_FEEDS.values():
resp = requests.head(url)
if "Last-Modified" not in resp.headers:
continue
last_modified = resp.headers["Last-Modified"]
last_modified = datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z")
if last_modified.date().isoformat() > last_update_date:
break
else:
if not force:
self.stdout.write(self.style.WARNING("Database already up-to-date."))
return
self.stdout.write("Updating database...")
for transport_type, feed_url in self.GTFS_FEEDS.items():
self.stdout.write(f"Downloading {transport_type} GTFS feed...")
with ZipFile(BytesIO(requests.get(feed_url).content)) as zipfile:
for gtfs_feed in GTFSFeed.objects.all():
gtfs_code = gtfs_feed.code
if not force:
# Check if the source file was updated
resp = requests.head(gtfs_feed.feed_url, allow_redirects=True)
if 'ETag' in resp.headers and gtfs_feed.etag:
if resp.headers['ETag'] == gtfs_feed.etag:
if verbosity >= 1:
self.stdout.write(self.style.WARNING(f"Database is already up-to-date for {gtfs_feed}."))
continue
if 'Last-Modified' in resp.headers and gtfs_feed.last_modified:
last_modified = resp.headers['Last-Modified']
last_modified = datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z") \
.replace(tzinfo=ZoneInfo(last_modified.split(' ')[-1]))
if last_modified <= gtfs_feed.last_modified:
if verbosity >= 1:
self.stdout.write(self.style.WARNING(f"Database is already up-to-date for {gtfs_feed}."))
continue
self.stdout.write(f"Downloading GTFS feed for {gtfs_feed}...")
resp = requests.get(gtfs_feed.feed_url, allow_redirects=True, stream=True)
with ZipFile(BytesIO(resp.content)) as zipfile:
def read_file(filename):
lines = zipfile.read(filename).decode().replace('\ufeff', '').splitlines()
return [line.strip() for line in lines]
@ -67,23 +58,25 @@ class Command(BaseCommand):
agencies = []
for agency_dict in csv.DictReader(read_file("agency.txt")):
agency_dict: dict
if transport_type == "ES" \
and agency_dict['agency_id'] != 'ES' and agency_dict['agency_id'] != 'ER':
continue
# if gtfs_code == "FR-EUROSTAR" \
# and agency_dict['agency_id'] != 'ES' and agency_dict['agency_id'] != 'ER':
# continue
agency = Agency(
id=agency_dict['agency_id'],
id=f"{gtfs_code}-{agency_dict['agency_id']}",
name=agency_dict['agency_name'],
url=agency_dict['agency_url'],
timezone=agency_dict['agency_timezone'],
lang=agency_dict.get('agency_lang', "fr"),
phone=agency_dict.get('agency_phone', ""),
email=agency_dict.get('agency_email', ""),
gtfs_feed=gtfs_feed,
)
agencies.append(agency)
if agencies and not dry_run:
Agency.objects.bulk_create(agencies,
update_conflicts=True,
update_fields=['name', 'url', 'timezone', 'lang', 'phone', 'email'],
update_fields=['name', 'url', 'timezone', 'lang', 'phone', 'email',
'gtfs_feed'],
unique_fields=['id'])
agencies.clear()
@ -91,8 +84,10 @@ class Command(BaseCommand):
for stop_dict in csv.DictReader(read_file("stops.txt")):
stop_dict: dict
stop_id = stop_dict['stop_id']
if transport_type in ["ES", "TI", "RENFE"]:
stop_id = f"{transport_type}-{stop_id}"
stop_id = f"{gtfs_code}-{stop_id}"
parent_station_id = stop_dict.get('parent_station', None)
parent_station_id = f"{gtfs_code}-{parent_station_id}" if parent_station_id else None
stop = Stop(
id=stop_id,
@ -102,13 +97,13 @@ class Command(BaseCommand):
lon=stop_dict['stop_lon'],
zone_id=stop_dict.get('zone_id', ""),
url=stop_dict.get('stop_url', ""),
location_type=stop_dict.get('location_type', 1) or 1,
parent_station_id=stop_dict.get('parent_station', None) or None,
location_type=stop_dict.get('location_type', 0) or 0,
parent_station_id=parent_station_id,
timezone=stop_dict.get('stop_timezone', ""),
wheelchair_boarding=stop_dict.get('wheelchair_boarding', 0),
level_id=stop_dict.get('level_id', ""),
platform_code=stop_dict.get('platform_code', ""),
transport_type=transport_type,
gtfs_feed=gtfs_feed,
)
stops.append(stop)
@ -119,7 +114,7 @@ class Command(BaseCommand):
update_fields=['name', 'desc', 'lat', 'lon', 'zone_id', 'url',
'location_type', 'parent_station_id', 'timezone',
'wheelchair_boarding', 'level_id', 'platform_code',
'transport_type'],
'gtfs_feed'],
unique_fields=['id'])
stops.clear()
@ -127,11 +122,10 @@ class Command(BaseCommand):
for route_dict in csv.DictReader(read_file("routes.txt")):
route_dict: dict
route_id = route_dict['route_id']
if transport_type == "TI":
route_id = f"{transport_type}-{route_id}"
route_id = f"{gtfs_code}-{route_id}"
route = Route(
id=route_id,
agency_id=route_dict['agency_id'],
agency_id=f"{gtfs_code}-{route_dict['agency_id']}",
short_name=route_dict['route_short_name'],
long_name=route_dict['route_long_name'],
desc=route_dict.get('route_desc', ""),
@ -139,7 +133,7 @@ class Command(BaseCommand):
url=route_dict.get('route_url', ""),
color=route_dict.get('route_color', ""),
text_color=route_dict.get('route_text_color', ""),
transport_type=transport_type,
gtfs_feed=gtfs_feed,
)
routes.append(route)
@ -148,7 +142,7 @@ class Command(BaseCommand):
update_conflicts=True,
update_fields=['agency_id', 'short_name', 'long_name', 'desc',
'type', 'url', 'color', 'text_color',
'transport_type'],
'gtfs_feed'],
unique_fields=['id'])
routes.clear()
if routes and not dry_run:
@ -156,17 +150,17 @@ class Command(BaseCommand):
update_conflicts=True,
update_fields=['agency_id', 'short_name', 'long_name', 'desc',
'type', 'url', 'color', 'text_color',
'transport_type'],
'gtfs_feed'],
unique_fields=['id'])
routes.clear()
Calendar.objects.filter(transport_type=transport_type).delete()
Calendar.objects.filter(gtfs_feed=gtfs_feed).delete()
calendars = {}
if "calendar.txt" in zipfile.namelist():
for calendar_dict in csv.DictReader(read_file("calendar.txt")):
calendar_dict: dict
calendar = Calendar(
id=f"{transport_type}-{calendar_dict['service_id']}",
id=f"{gtfs_code}-{calendar_dict['service_id']}",
monday=calendar_dict['monday'],
tuesday=calendar_dict['tuesday'],
wednesday=calendar_dict['wednesday'],
@ -176,7 +170,7 @@ class Command(BaseCommand):
sunday=calendar_dict['sunday'],
start_date=calendar_dict['start_date'],
end_date=calendar_dict['end_date'],
transport_type=transport_type,
gtfs_feed=gtfs_feed,
)
calendars[calendar.id] = calendar
@ -185,14 +179,14 @@ class Command(BaseCommand):
update_conflicts=True,
update_fields=['monday', 'tuesday', 'wednesday', 'thursday',
'friday', 'saturday', 'sunday', 'start_date',
'end_date', 'transport_type'],
'end_date', 'gtfs_feed'],
unique_fields=['id'])
calendars.clear()
if calendars and not dry_run:
Calendar.objects.bulk_create(calendars.values(), update_conflicts=True,
update_fields=['monday', 'tuesday', 'wednesday', 'thursday',
'friday', 'saturday', 'sunday', 'start_date',
'end_date', 'transport_type'],
'end_date', 'gtfs_feed'],
unique_fields=['id'])
calendars.clear()
@ -200,8 +194,8 @@ class Command(BaseCommand):
for calendar_date_dict in csv.DictReader(read_file("calendar_dates.txt")):
calendar_date_dict: dict
calendar_date = CalendarDate(
id=f"{transport_type}-{calendar_date_dict['service_id']}-{calendar_date_dict['date']}",
service_id=f"{transport_type}-{calendar_date_dict['service_id']}",
id=f"{gtfs_code}-{calendar_date_dict['service_id']}-{calendar_date_dict['date']}",
service_id=f"{gtfs_code}-{calendar_date_dict['service_id']}",
date=calendar_date_dict['date'],
exception_type=calendar_date_dict['exception_type'],
)
@ -209,7 +203,7 @@ class Command(BaseCommand):
if calendar_date.service_id not in calendars:
calendar = Calendar(
id=f"{transport_type}-{calendar_date_dict['service_id']}",
id=f"{gtfs_code}-{calendar_date_dict['service_id']}",
monday=False,
tuesday=False,
wednesday=False,
@ -219,11 +213,11 @@ class Command(BaseCommand):
sunday=False,
start_date=calendar_date_dict['date'],
end_date=calendar_date_dict['date'],
transport_type=transport_type,
gtfs_feed=gtfs_feed,
)
calendars[calendar.id] = calendar
else:
calendar = calendars[f"{transport_type}-{calendar_date_dict['service_id']}"]
calendar = calendars[f"{gtfs_code}-{calendar_date_dict['service_id']}"]
if calendar.start_date > calendar_date.date:
calendar.start_date = calendar_date.date
if calendar.end_date < calendar_date.date:
@ -233,7 +227,7 @@ class Command(BaseCommand):
Calendar.objects.bulk_create(calendars.values(),
batch_size=bulk_size,
update_conflicts=True,
update_fields=['start_date', 'end_date'],
update_fields=['start_date', 'end_date', 'gtfs_feed'],
unique_fields=['id'])
CalendarDate.objects.bulk_create(calendar_dates,
batch_size=bulk_size,
@ -248,22 +242,12 @@ class Command(BaseCommand):
trip_dict: dict
trip_id = trip_dict['trip_id']
route_id = trip_dict['route_id']
if transport_type in ["TGV", "IC", "TER"]:
trip_id, last_update = trip_id.split(':', 1)
last_update = datetime.fromisoformat(last_update)
elif transport_type in ["ES", "RENFE"]:
trip_id = f"{transport_type}-{trip_id}"
last_update = None
elif transport_type == "TI":
trip_id = f"{transport_type}-{trip_id}"
route_id = f"{transport_type}-{route_id}"
last_update = None
else:
last_update = None
trip_id = f"{gtfs_code}-{trip_id}"
route_id = f"{gtfs_code}-{route_id}"
trip = Trip(
id=trip_id,
route_id=route_id,
service_id=f"{transport_type}-{trip_dict['service_id']}",
service_id=f"{gtfs_code}-{trip_dict['service_id']}",
headsign=trip_dict.get('trip_headsign', ""),
short_name=trip_dict.get('trip_short_name', ""),
direction_id=trip_dict.get('direction_id', None) or None,
@ -271,7 +255,7 @@ class Command(BaseCommand):
shape_id=trip_dict.get('shape_id', ""),
wheelchair_accessible=trip_dict.get('wheelchair_accessible', None),
bikes_allowed=trip_dict.get('bikes_allowed', None),
last_update=last_update,
gtfs_feed=gtfs_feed,
)
trips.append(trip)
@ -280,7 +264,7 @@ class Command(BaseCommand):
update_conflicts=True,
update_fields=['route_id', 'service_id', 'headsign', 'short_name',
'direction_id', 'block_id', 'shape_id',
'wheelchair_accessible', 'bikes_allowed'],
'wheelchair_accessible', 'bikes_allowed', 'gtfs_feed'],
unique_fields=['id'])
trips.clear()
if trips and not dry_run:
@ -288,7 +272,7 @@ class Command(BaseCommand):
update_conflicts=True,
update_fields=['route_id', 'service_id', 'headsign', 'short_name',
'direction_id', 'block_id', 'shape_id',
'wheelchair_accessible', 'bikes_allowed'],
'wheelchair_accessible', 'bikes_allowed', 'gtfs_feed'],
unique_fields=['id'])
trips.clear()
@ -297,14 +281,10 @@ class Command(BaseCommand):
stop_time_dict: dict
stop_id = stop_time_dict['stop_id']
if transport_type in ["ES", "TI", "RENFE"]:
stop_id = f"{transport_type}-{stop_id}"
stop_id = f"{gtfs_code}-{stop_id}"
trip_id = stop_time_dict['trip_id']
if transport_type in ["TGV", "IC", "TER"]:
trip_id = trip_id.split(':', 1)[0]
elif transport_type in ["ES", "TI", "RENFE"]:
trip_id = f"{transport_type}-{trip_id}"
trip_id = f"{gtfs_code}-{trip_id}"
arr_time = stop_time_dict['arrival_time']
arr_h, arr_m, arr_s = map(int, arr_time.split(':'))
@ -315,19 +295,16 @@ class Command(BaseCommand):
pickup_type = stop_time_dict.get('pickup_type', 0)
drop_off_type = stop_time_dict.get('drop_off_type', 0)
if transport_type in ["ES", "RENFE", "OBB"]:
if stop_time_dict['stop_sequence'] == "1":
drop_off_type = 1
elif arr_time == dep_time:
pickup_type = 1
elif transport_type == "TI":
if stop_time_dict['stop_sequence'] == "0":
drop_off_type = 1
elif arr_time == dep_time:
pickup_type = 1
if stop_time_dict['stop_sequence'] == "1":
# First stop
drop_off_type = PickupType.NONE
elif arr_time == dep_time:
# Last stop
pickup_type = PickupType.NONE
st = StopTime(
id=f"{trip_id}-{stop_id}-{stop_time_dict['departure_time']}",
id=f"{gtfs_code}-{stop_time_dict['trip_id']}-{stop_time_dict['stop_id']}"
f"-{stop_time_dict['departure_time']}",
trip_id=trip_id,
arrival_time=timedelta(seconds=arr_time),
departure_time=timedelta(seconds=dep_time),
@ -363,14 +340,13 @@ class Command(BaseCommand):
transfer_dict: dict
from_stop_id = transfer_dict['from_stop_id']
to_stop_id = transfer_dict['to_stop_id']
if transport_type in ["ES", "RENFE", "OBB"]:
from_stop_id = f"{transport_type}-{from_stop_id}"
to_stop_id = f"{transport_type}-{to_stop_id}"
from_stop_id = f"{gtfs_code}-{from_stop_id}"
to_stop_id = f"{gtfs_code}-{to_stop_id}"
transfer = Transfer(
id=f"{from_stop_id}-{to_stop_id}",
from_stop_id=transfer_dict['from_stop_id'],
to_stop_id=transfer_dict['to_stop_id'],
id=f"{transfer_dict['from_stop_id']}-{transfer_dict['to_stop_id']}",
from_stop_id=from_stop_id,
to_stop_id=to_stop_id,
transfer_type=transfer_dict['transfer_type'],
min_transfer_time=transfer_dict['min_transfer_time'],
)
@ -395,6 +371,7 @@ class Command(BaseCommand):
feed_info_dict: dict
FeedInfo.objects.update_or_create(
publisher_name=feed_info_dict['feed_publisher_name'],
gtfs_feed=gtfs_feed,
defaults=dict(
publisher_url=feed_info_dict['feed_publisher_url'],
lang=feed_info_dict['feed_lang'],
@ -403,3 +380,12 @@ class Command(BaseCommand):
version=feed_info_dict.get('feed_version', 1),
)
)
if 'ETag' in resp.headers:
gtfs_feed.etag = resp.headers['ETag']
gtfs_feed.save()
if 'Last-Modified' in resp.headers:
last_modified = resp.headers['Last-Modified']
gtfs_feed.last_modified = datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z") \
.replace(tzinfo=ZoneInfo(last_modified.split(' ')[-1]))
gtfs_feed.save()