224 lines
7 KiB
Python
224 lines
7 KiB
Python
# Copyright (C) 2023 Noisytoot
|
|
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as
|
|
# published by the Free Software Foundation, either version 3 of the
|
|
# License, or (at your option) any later version.
|
|
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
import json
|
|
import configparser
|
|
import base64
|
|
import os
|
|
import time
|
|
import sys
|
|
import secrets
|
|
import signal
|
|
import subprocess
|
|
import sqlite3
|
|
import uuid
|
|
import bottle
|
|
from bottle import run, default_app, get, post, request, response
|
|
import passlib
|
|
from passlib.hash import bcrypt
|
|
|
|
# All settings are read once, at import time, from acme-dns-zonefile.conf
# (a relative path, so it is resolved against the current working directory).
config = configparser.ConfigParser()
if not config.read('acme-dns-zonefile.conf'):
    # ConfigParser.read() silently skips files it cannot open and returns the
    # list of files it did read; without this check the first required lookup
    # below would fail with an unexplained KeyError.
    print("Error: Could not read acme-dns-zonefile.conf, aborting.", file=sys.stderr)
    sys.exit(1)

# [bind] (optional section): address of the HTTP API.  fallback= also covers
# a missing section, unlike config['bind'].get(...).  getint() makes the port
# an int whether or not it was set in the file.
bind_host = config.get('bind', 'host', fallback='localhost')
bind_port = config.getint('bind', 'port', fallback=8080)

# [general] (required section): zone identity, storage and token/password sizes.
domain = config['general']['domain']
nsname = config['general'].get('nsname', domain)
nsadmin = config['general']['nsadmin']
zonefile = config['general']['zonefile']
dbfile = config['general'].get('dbfile', 'acme-dns-zonefile.db')
# Extra zone-file text inserted verbatim after the SOA record.
records = config['general'].get('records', '')
# Accepted TXT value lengths, clamped to 1..255 with max >= min.
token_length_min = min(255, max(1, config.getint('general', 'token_length_min', fallback=43)))
token_length_max = max(token_length_min, min(255, max(1, config.getint('general', 'token_length_max', fallback=43))))
# Passed to secrets.token_urlsafe() as a byte count when registering accounts.
password_length = config.getint('general', 'password_length', fallback=30)

# [reload] (optional section): how the DNS server is told to reload.
reload_method = config.get('reload', 'method', fallback='none')

# [api] (optional section): switches for individual endpoints.
disable_registration = config.getboolean('api', 'disable_registration', fallback=False)
disable_reload = config.getboolean('api', 'disable_reload', fallback=False)
|
|
|
|
# A single module-wide SQLite connection and cursor, shared by the database
# helpers and the request handlers defined below.
con = sqlite3.connect(database=dbfile)
cur = con.cursor()
|
|
|
|
def check_db_version():
    """Make sure the database schema version is one this code understands.

    If no ``db_version`` row exists yet, one with value ``'1'`` is inserted.
    This function does not commit; the caller does.  Any stored version
    other than ``'1'`` prints an error to stderr and exits with status 1.
    """
    row = cur.execute("SELECT Value FROM acmedns WHERE name='db_version' LIMIT 1").fetchone()
    if row is not None:
        if row[0] != '1':
            print(f"Error: Unexpected db version {row[0]}, aborting.", file=sys.stderr)
            sys.exit(1)
        return
    cur.execute("INSERT INTO acmedns VALUES ('db_version', '1')")
|
|
|
|
# Create the schema on first start; IF NOT EXISTS makes this a no-op later.
#   acmedns: name/value metadata (currently just the schema version).
#   records: one row per registered account: username, bcrypt password hash,
#            subdomain, and an AllowFrom column that register() sets to ''.
#   txt:     the TXT values served per subdomain, with their update time.
for _ddl in (
    "CREATE TABLE IF NOT EXISTS acmedns(\n"
    "Name TEXT,\n"
    "Value TEXT\n"
    ");",
    "CREATE TABLE IF NOT EXISTS records(\n"
    "Username TEXT UNIQUE NOT NULL PRIMARY KEY,\n"
    "Password TEXT UNIQUE NOT NULL,\n"
    "Subdomain TEXT UNIQUE NOT NULL,\n"
    "AllowFrom TEXT\n"
    ");",
    "CREATE TABLE IF NOT EXISTS txt(\n"
    "Subdomain TEXT NOT NULL,\n"
    "Value TEXT NOT NULL DEFAULT '',\n"
    "LastUpdate INT\n"
    ");",
):
    cur.execute(_ddl)
del _ddl

# Record or verify the schema version, then persist everything above.
check_db_version()
con.commit()
|
|
|
|
def commentify(string):
    """Return *string* with ``'# '`` placed at the start of every line.

    Only ``'\\n'`` separates lines, so an empty string becomes ``'# '`` and
    a trailing newline produces a final, otherwise empty ``'# '`` line.
    """
    return '# ' + string.replace('\n', '\n# ')
|
|
|
|
# Text served by GET /: a commented header naming the interpreter and library
# versions (sys.version may span several lines, hence commentify), followed by
# this file's own source code.
source_str = commentify(f"""\
Python {sys.version} on {sys.platform}
bottle {bottle.__version__}
passlib {passlib.__version__}""") + '\n\n'

# Python source files are UTF-8 by default (PEP 3120) and GET / declares
# charset=utf-8, so decode explicitly instead of with the locale's encoding.
with open(__file__, encoding='utf-8') as source_file:
    source_str += source_file.read()
|
|
|
|
# Characters allowed in a challenge token: the base64url alphabet plus '='.
# None of them needs escaping inside a double-quoted zone-file TXT string.
_TOKEN_ALPHABET = frozenset(
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_="
)


def is_valid_token(s, *, min_length=None, max_length=None):
    """Return True if *s* is acceptable as a challenge token / TXT value.

    A valid token is a ``str`` whose length lies in [min_length, max_length]
    and which contains only base64url characters (``A-Z a-z 0-9 - _``) or
    ``=``.  Anything that is not a ``str`` (e.g. a number or list taken from
    a JSON request body) is rejected instead of raising or slipping through.

    min_length / max_length default to the configured token_length_min /
    token_length_max; pass them to check against explicit bounds.
    """
    if not isinstance(s, str):
        return False
    lo = token_length_min if min_length is None else min_length
    hi = token_length_max if max_length is None else max_length
    if not lo <= len(s) <= hi:
        return False
    return set(s) <= _TOKEN_ALPHABET
|
|
|
|
# Highest SOA serial written by this process (0 until the first write).
_last_serial = 0


def write_zonefile():
    """Regenerate the zone file from the configuration and the txt table.

    The output is an $ORIGIN line, an SOA record, the static ``records``
    text from the configuration, and one TXT record (TTL 1) for every txt
    row whose value passes is_valid_token().  Rows that fail the check,
    such as the empty placeholder rows register() inserts, are skipped.
    The file named by ``zonefile`` is overwritten in place.
    """
    global _last_serial
    # Base the serial on Unix time, but never reuse or decrease it within
    # this process.  Two updates in the same second (e.g. both TXT values of
    # a subdomain set back to back) must still produce a higher serial, or
    # servers that compare serials may treat the new zone as unchanged.
    serial = max(int(time.time()), _last_serial + 1)
    _last_serial = serial

    # Owner names must start at column 0; the SOA fields sit inside
    # parentheses, where line breaks and indentation are insignificant.
    header = f"""\
$ORIGIN {domain}.
{domain}. 3600 SOA {nsname}. {nsadmin}. (
    {serial} ; serial
    1800 ; refresh
    3600 ; retry
    86400 ; expire
    3600 ) ; minimum TTL
{records}
"""
    parts = [header]
    for subdomain, value in cur.execute("SELECT Subdomain, Value FROM txt"):
        # Values are validated before they are stored; re-check anyway so
        # nothing that could break zone-file syntax is ever written.
        if is_valid_token(value):
            parts.append(f'{subdomain}.{domain}.\t1\tIN\tTXT\t"{value}"\n')

    with open(zonefile, 'w', encoding='utf-8') as f:
        f.write(''.join(parts))
|
|
|
|
def set_txt(subdomain, txt):
    """Write *txt* into the least recently updated txt row of *subdomain*.

    The row with the smallest LastUpdate is overwritten and stamped with the
    current Unix time, so the most recent values stay in the table.  The
    change is committed immediately.  If the subdomain has no txt rows,
    nothing is updated.
    """
    statement = (
        "UPDATE txt SET Value=?, LastUpdate=?\n"
        "WHERE rowid=(\n"
        "SELECT rowid FROM txt WHERE Subdomain=? ORDER BY LastUpdate LIMIT 1\n"
        ")"
    )
    now = int(time.time())
    cur.execute(statement, (txt, now, subdomain))
    con.commit()
|
|
|
|
def reload_dns_server():
    """Tell the DNS server to reload, according to ``[reload] method``.

    * ``signal``: send SIGHUP to the PID read from ``[reload] pidfile``.
    * ``exec``: start ``[reload] command`` through the shell without waiting
      for it to finish.
    * ``none``: do nothing.

    Reloading is best-effort.  Failures are printed to stderr as warnings
    rather than raised, because callers such as update() have already
    stored the new TXT value and rewritten the zone file when this runs.
    """
    if reload_method == 'signal':
        pidfile = config['reload']['pidfile']
        try:
            with open(pidfile) as f:
                pid = int(f.read())
        except (OSError, ValueError) as e:
            # Missing/unreadable pidfile, or contents that are not an integer.
            print(f"Warning: Failed to reload DNS server: cannot read pid from {pidfile}: {e}", file=sys.stderr)
            return
        try:
            os.kill(pid, signal.SIGHUP)
        except PermissionError:
            print("Warning: Failed to reload DNS server: operation not permitted", file=sys.stderr)
        except ProcessLookupError:
            # Stale pidfile: the server is not running under that PID.
            print(f"Warning: Failed to reload DNS server: no process with pid {pid}", file=sys.stderr)
    elif reload_method == 'exec':
        command = config['reload']['command']
        # With shell=True the documented form of args is a single string.
        # It is identical to [command] on POSIX and avoids re-quoting on Windows.
        subprocess.Popen(command, shell=True)
    elif reload_method != 'none':
        print(f"Warning: Unknown reload method {reload_method!r}, not reloading DNS server", file=sys.stderr)
|
|
|
|
@post('/register')
def register():
    """Create a new account with its own subdomain (POST /register).

    When registration is disabled in the configuration, answers 403 with an
    error object.  Otherwise it generates a random username, password and
    subdomain.  It stores only the bcrypt hash of the password, plus two
    empty txt rows for the subdomain, and answers 201 with the credentials.
    """
    if disable_registration:
        response.status = 403
        return {"error": "Registration endpoint disabled"}

    response.status = 201
    username = str(uuid.uuid4())
    password = secrets.token_urlsafe(password_length)
    subdomain = str(uuid.uuid4())
    cur.execute(
        "INSERT INTO records VALUES (?, ?, ?, '')",
        (username, bcrypt.hash(password), subdomain),
    )
    # Two txt rows per subdomain: set_txt() overwrites the older one, so the
    # two most recent challenge values are served at the same time.
    for _ in range(2):
        cur.execute("INSERT INTO txt (Subdomain, LastUpdate) VALUES (?, 0)", (subdomain,))
    con.commit()

    return dict(
        allowfrom=[],
        fulldomain=f"{subdomain}.{domain}",
        password=password,
        subdomain=subdomain,
        username=username,
    )
|
|
|
|
@post('/update')
def update():
    """Set a TXT value for the caller's subdomain (POST /update).

    Expects the ``X-Api-User`` / ``X-Api-Key`` headers holding the
    credentials returned by /register, and a JSON object body with
    ``subdomain`` and ``txt`` fields.  On success it stores the value,
    rewrites the zone file, triggers a DNS server reload and echoes
    ``{"txt": ...}``.

    Errors are returned as ``{"error": ...}`` with status 400 (malformed
    body or invalid token) or 403 (unknown user, foreign subdomain, wrong
    or missing password).
    """
    req_username = request.headers.get('X-Api-User')
    req_password = request.headers.get('X-Api-Key')

    try:
        data = json.loads(request.body.read())
    except ValueError:
        # Covers JSONDecodeError and also UnicodeDecodeError for bodies that
        # are not valid UTF-8/16/32; both subclass ValueError.
        response.status = 400
        return {"error": "Invalid JSON"}
    # Valid JSON that is not an object, or lacks a field, would otherwise
    # raise TypeError/KeyError and surface as HTTP 500.
    if not isinstance(data, dict) or 'subdomain' not in data or 'txt' not in data:
        response.status = 400
        return {"error": "Missing subdomain or txt"}
    req_subdomain = data['subdomain']
    req_txt = data['txt']

    record = cur.execute(
        "SELECT Password, Subdomain FROM records WHERE Username=? LIMIT 1",
        (req_username,),
    ).fetchone()
    if record is None:
        response.status = 403
        return {"error": "Invalid username"}
    db_password, db_subdomain = record

    if req_subdomain != db_subdomain:
        response.status = 403
        return {"error": "Invalid subdomain"}

    # A missing X-Api-Key header gives None, which bcrypt.verify() rejects
    # with a TypeError (an HTTP 500) rather than returning False.
    if req_password is None or not bcrypt.verify(req_password, db_password):
        response.status = 403
        return {"error": "Invalid password"}

    # Check the type here too, so a JSON number/list cannot reach the
    # length/alphabet check or the database.
    if not isinstance(req_txt, str) or not is_valid_token(req_txt):
        response.status = 400
        return {"error": "Invalid challenge token"}

    set_txt(req_subdomain, req_txt)
    write_zonefile()
    reload_dns_server()
    return {
        "txt": req_txt
    }
|
|
|
|
@get('/health')
def health():
    """Liveness check (GET /health): sets status 200 and returns nothing."""
    response.status = 200
    return None
|
|
|
|
@get('/reload')
def reload():
    """Regenerate the zone file and reload the DNS server (GET /reload).

    Answers 403 with an error object when the endpoint is disabled in the
    configuration; otherwise returns nothing after reloading.
    """
    if not disable_reload:
        write_zonefile()
        reload_dns_server()
        return None
    response.status = 403
    return {"error": "Reload endpoint disabled"}
|
|
|
|
@get('/')
def source():
    """Serve source_str (version header plus this file's code) as UTF-8 text."""
    body = source_str
    response.content_type = 'text/plain; charset=utf-8'
    return body
|
|
|
|
if __name__ != "__main__":
    # Imported (e.g. by a WSGI server): expose the default Bottle application
    # under both the `app` and `application` names.
    app = application = default_app()
else:
    # Executed directly: start Bottle's run() on the configured address.
    run(host=bind_host, port=bind_port)
|