acme-dns-zonefile/acme-dns-zonefile.py

225 lines
7.0 KiB
Python

# Copyright (C) 2023 Noisytoot
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import json
import configparser
import base64
import os
import time
import sys
import secrets
import signal
import subprocess
import sqlite3
import uuid
import bottle
from bottle import run, default_app, get, post, request, response
import passlib
from passlib.hash import bcrypt
# Settings come from acme-dns-zonefile.conf in the current working directory.
# [general] is required; the [bind], [reload] and [api] options all have
# fallbacks, so those sections may be omitted entirely.
config = configparser.ConfigParser()
config.read('acme-dns-zonefile.conf')
bind_host = config.get('bind', 'host', fallback='localhost')
bind_port = config.getint('bind', 'port', fallback=8080)
domain = config['general']['domain']
nsname = config['general'].get('nsname', domain)
nsadmin = config['general']['nsadmin']
zonefile = config['general']['zonefile']
dbfile = config['general'].get('dbfile', 'acme-dns-zonefile.db')
# Extra zone-file text inserted verbatim after the SOA record.
records = config['general'].get('records', '')
# Accepted challenge-token length range, clamped to 1..255 with max >= min
# (presumably 255 because a single TXT character-string holds at most 255 bytes).
token_length_min = min(255, max(1, int(config['general'].get('token_length_min', 43))))
token_length_max = max(token_length_min, min(255, max(1, int(config['general'].get('token_length_max', 43)))))
# Bytes of randomness for generated API passwords. Clamped to at least 1:
# secrets.token_urlsafe(0) returns '' and would hand out empty passwords.
password_length = max(1, int(config['general'].get('password_length', 30)))
reload_method = config.get('reload', 'method', fallback='none')
disable_registration = config.getboolean('api', 'disable_registration', fallback=False)
disable_reload = config.getboolean('api', 'disable_reload', fallback=False)
# One shared connection and cursor, used by every function below.
con = sqlite3.connect(dbfile)
cur = con.cursor()
def check_db_version():
    """Make sure the database schema version is one this program understands.

    If no ``db_version`` row exists yet, version '1' is inserted into the
    ``acmedns`` table (left uncommitted for the caller). Any other stored
    version prints an error to stderr and exits the process with status 1.
    """
    row = cur.execute(
        "SELECT Value FROM acmedns WHERE name='db_version' LIMIT 1"
    ).fetchone()
    if row is None:
        # Fresh database: record the current schema version.
        cur.execute("INSERT INTO acmedns VALUES ('db_version', '1')")
        return
    (stored_version,) = row
    if stored_version == '1':
        return
    print(f"Error: Unexpected db version {stored_version}, aborting.", file=sys.stderr)
    sys.exit(1)
# Create the schema. Every statement uses IF NOT EXISTS, so running this on
# each start-up is harmless.
#  - acmedns: key/value metadata (holds the schema version)
#  - records: one row per registered account (bcrypt password hash)
#  - txt:     challenge values, two rows per subdomain
for _schema_sql in (
    """\
CREATE TABLE IF NOT EXISTS acmedns(
Name TEXT,
Value TEXT
);""",
    """\
CREATE TABLE IF NOT EXISTS records(
Username TEXT UNIQUE NOT NULL PRIMARY KEY,
Password TEXT UNIQUE NOT NULL,
Subdomain TEXT UNIQUE NOT NULL,
AllowFrom TEXT
);""",
    """\
CREATE TABLE IF NOT EXISTS txt(
Subdomain TEXT NOT NULL,
Value TEXT NOT NULL DEFAULT '',
LastUpdate INT
);""",
):
    cur.execute(_schema_sql)
del _schema_sql
# Verify (or record) the schema version, then persist everything at once.
check_db_version()
con.commit()
def commentify(string):
    """Return *string* with ``# `` prepended to every line.

    Lines are split on ``'\\n'`` only, so a trailing newline yields a final
    line containing just ``# ``.
    """
    return '\n'.join(f'# {line}' for line in string.split('\n'))
# Plain-text copy of this program, served by the GET / route. It is built once
# at import time, with a commented header listing the interpreter and library
# versions in use.
source_str = commentify(f"""\
Python {sys.version} on {sys.platform}
bottle {bottle.__version__}
passlib {passlib.__version__}""") + '\n\n'
# Decode explicitly as UTF-8. The text is served with charset=utf-8, and the
# locale's default encoding need not be UTF-8.
with open(__file__, encoding='utf-8') as source_file:
    source_str += source_file.read()
def is_valid_token(s, min_length=None, max_length=None):
    """Return True if *s* is an acceptable challenge token.

    A valid token is a ``str`` whose length lies within the inclusive bounds
    and which contains only base64url characters (plus ``=`` padding).

    :param s: candidate value. It may come straight from request JSON, so
        anything that is not a ``str`` is rejected rather than raising (or,
        for e.g. a list of one-character strings, wrongly accepted).
    :param min_length: inclusive lower length bound; defaults to the
        module-level ``token_length_min``.
    :param max_length: inclusive upper length bound; defaults to the
        module-level ``token_length_max``.
    :returns: ``True`` if valid, otherwise ``False``.
    """
    if not isinstance(s, str):
        return False
    if min_length is None:
        min_length = token_length_min
    if max_length is None:
        max_length = token_length_max
    if not min_length <= len(s) <= max_length:
        return False
    # base64url alphabet, plus '=' so padded values are also accepted
    alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_="
    return all(c in alphabet for c in s)
# Serial used by the most recent write_zonefile() call in this process.
_last_serial = 0


def write_zonefile():
    """Regenerate the zone file from the configuration and the ``txt`` table.

    The zone contains an SOA record for ``domain``, the static ``records``
    text from the config, and one TXT record (TTL 1) per stored challenge
    value that passes ``is_valid_token``. Rows still holding the empty
    default value are therefore skipped. The file at ``zonefile`` is
    overwritten in place.
    """
    global _last_serial
    # A Unix timestamp makes a convenient serial, but two writes in the same
    # second (e.g. wildcard and bare-domain challenges back to back) would
    # reuse it. Secondaries only pick up a change when the serial increases,
    # so always move forward by at least one.
    serial = max(int(time.time()), _last_serial + 1)
    _last_serial = serial
    parts = [f"""\
$ORIGIN {domain}.
{domain}. 3600 SOA {nsname}. {nsadmin}. (
{serial} ; serial
1800 ; refresh
3600 ; retry
86400 ; expire
3600 ) ; minimum TTL
{records}
"""]
    for subdomain, value in cur.execute("SELECT Subdomain, Value FROM txt"):
        # it should already be valid, but double-check anyway
        if is_valid_token(value):
            parts.append(f'{subdomain}.{domain}.\t1\tIN\tTXT\t"{value}"\n')
    with open(zonefile, 'w') as f:
        f.write(''.join(parts))
def set_txt(subdomain, txt):
    """Store *txt* for *subdomain*, overwriting its least recently updated row.

    ``register`` creates two ``txt`` rows per subdomain, so the two newest
    challenge values stay available side by side. The change is committed
    immediately.
    """
    now = int(time.time())
    query = """\
UPDATE txt SET Value=?, LastUpdate=?
WHERE rowid=(
SELECT rowid FROM txt WHERE Subdomain=? ORDER BY LastUpdate LIMIT 1
)"""
    cur.execute(query, (txt, now, subdomain))
    con.commit()
def reload_dns_server():
    """Ask the DNS server to reload, as selected by ``[reload] method``.

    * ``signal``: send SIGHUP to the PID read from ``[reload] pidfile``.
    * ``exec``: run ``[reload] command`` through the shell without waiting
      for it to finish.
    * any other value (including the default ``none``): do nothing.

    Failures are reported on stderr rather than raised. Callers invoke this
    after the zone file has already been rewritten, so an exception here
    would only turn an applied update into an HTTP 500.
    """
    if reload_method == 'signal':
        pidfile = config['reload']['pidfile']
        try:
            with open(pidfile) as f:
                pid = int(f.read())
            os.kill(pid, signal.SIGHUP)
        except PermissionError:
            print("Warning: Failed to reload DNS server: operation not permitted", file=sys.stderr)
        except (OSError, ValueError) as e:
            # Missing/unreadable pidfile, non-numeric contents, or a stale PID
            # (ProcessLookupError) -- same best-effort treatment as above.
            print(f"Warning: Failed to reload DNS server: {e}", file=sys.stderr)
    elif reload_method == 'exec':
        command = config['reload']['command']
        # The command comes from the operator's config file, so running it
        # through the shell is intended. With shell=True a plain string is the
        # documented argument form (a one-item list gets re-quoted on Windows).
        try:
            subprocess.Popen(command, shell=True)
        except OSError as e:
            print(f"Warning: Failed to reload DNS server: {e}", file=sys.stderr)
@post('/register')
def register():
    """Create a new account with its own challenge subdomain.

    Responds 403 when ``[api] disable_registration`` is set. Otherwise it
    responds 201 with freshly generated credentials. Only the bcrypt hash of
    the password is stored, so the plaintext is returned here and nowhere else.
    """
    if disable_registration:
        response.status = 403
        return {"error": "Registration endpoint disabled"}
    response.status = 201
    user = str(uuid.uuid4())
    secret = secrets.token_urlsafe(password_length)
    hashed = bcrypt.hash(secret)
    sub = str(uuid.uuid4())
    # AllowFrom is stored as an empty string; the response reports [] for it.
    cur.execute("INSERT INTO records VALUES (?, ?, ?, '')", (user, hashed, sub))
    # set_txt() rotates between two rows per subdomain, so seed both now.
    for _ in range(2):
        cur.execute("INSERT INTO txt (Subdomain, LastUpdate) VALUES (?, 0)", (sub,))
    con.commit()
    return {
        "allowfrom": [],
        "fulldomain": f"{sub}.{domain}",
        "password": secret,
        "subdomain": sub,
        "username": user,
    }
@post('/update')
def update():
    """Store a new TXT challenge value for the caller's subdomain.

    Authentication uses the ``X-Api-User`` / ``X-Api-Key`` headers holding
    the credentials issued by ``/register``. The body must be a JSON object
    with string fields ``subdomain`` and ``txt``. On success the value is
    saved, the zone file is regenerated, the DNS server is asked to reload,
    and ``{"txt": <value>}`` is returned.

    Errors come back with an ``{"error": ...}`` body: 400 for a malformed
    request or token, 403 for failed authentication.
    """
    req_username = request.headers.get('X-Api-User')
    req_password = request.headers.get('X-Api-Key')
    try:
        # ValueError covers JSONDecodeError and also the UnicodeDecodeError
        # json.loads raises for a body that is not valid UTF-8.
        data = json.loads(request.body.read())
    except ValueError:
        response.status = 400
        return { "error": "Invalid JSON" }
    # Validate the shape up front. Otherwise a missing key or a non-string
    # value would surface later as an uncaught exception (HTTP 500).
    if not isinstance(data, dict):
        response.status = 400
        return { "error": "Expected a JSON object" }
    req_subdomain = data.get('subdomain')
    req_txt = data.get('txt')
    if not isinstance(req_subdomain, str):
        response.status = 400
        return { "error": "Missing or invalid subdomain" }
    if not isinstance(req_txt, str):
        response.status = 400
        return { "error": "Missing or invalid txt" }
    res = cur.execute("SELECT Password, Subdomain FROM records WHERE Username=? LIMIT 1", (req_username,))
    record = res.fetchone()
    if record is None:
        response.status = 403
        return { "error": "Invalid username" }
    (db_password, db_subdomain) = record
    # Check the password before comparing subdomains, so an unauthenticated
    # caller cannot probe which subdomain belongs to a username. A missing
    # X-Api-Key header would otherwise make bcrypt.verify raise.
    if req_password is None or not bcrypt.verify(req_password, db_password):
        response.status = 403
        return { "error": "Invalid password" }
    if req_subdomain != db_subdomain:
        response.status = 403
        return { "error": "Invalid subdomain" }
    if not is_valid_token(req_txt):
        response.status = 400
        return { "error": "Invalid challenge token" }
    set_txt(req_subdomain, req_txt)
    write_zonefile()
    reload_dns_server()
    return {
        "txt": req_txt
    }
@get('/health')
def health():
    """Liveness check: answer with status 200 and no body content."""
    response.status = 200
    return None
@get('/reload')
def reload():
    """Rewrite the zone file and trigger a DNS server reload on demand.

    Responds 403 when ``[api] disable_reload`` is set. Otherwise the handler
    returns nothing after doing the work.
    """
    if not disable_reload:
        write_zonefile()
        reload_dns_server()
        return None
    response.status = 403
    return {"error": "Reload endpoint disabled"}
@get('/')
def source():
    """Serve this program's source (with its version header) as UTF-8 text.

    Presumably this is how the service offers its source to network users,
    as the AGPL header requires -- confirm with the maintainer.
    """
    body = source_str
    response.content_type = 'text/plain; charset=utf-8'
    return body
if __name__ != "__main__":
    # Imported by a WSGI server: expose bottle's default app under both of
    # the names such servers commonly look for.
    app = application = default_app()
else:
    # Run directly: start bottle's built-in server on the configured address.
    run(host=bind_host, port=bind_port)