import os, base64, hashlib, secrets, urllib.parse, requests
from http.server import BaseHTTPRequestHandler, HTTPServer
from dotenv import load_dotenv

load_dotenv(".env")


def _env(name, default=""):
    """Return env var *name* with surrounding whitespace stripped.

    Falls back to *default* when the variable is unset OR set to a blank value
    (e.g. ``X_REDIRECT_URI=`` in .env), so an empty line in .env does not
    silently override a sensible default.
    """
    return (os.getenv(name) or "").strip() or default


# OAuth client credentials and request settings, read from .env / the environment.
CLIENT_ID = _env("X_CLIENT_ID")
CLIENT_SECRET = _env("X_CLIENT_SECRET")
REDIRECT_URI = _env("X_REDIRECT_URI", "http://127.0.0.1:8088/callback")
SCOPES = _env("X_SCOPES", "tweet.read tweet.write users.read offline.access")

if not CLIENT_ID or not CLIENT_SECRET:
    raise SystemExit("Missing X_CLIENT_ID or X_CLIENT_SECRET in .env")

def b64url(b: bytes) -> str:
    """Encode *b* as base64url text (RFC 4648 §5) with the ``=`` padding removed."""
    encoded = base64.urlsafe_b64encode(b)
    return encoded.rstrip(b"=").decode()

# PKCE: a random code verifier and its S256 challenge, plus a random `state`
# value that the callback must echo back unchanged.
code_verifier = b64url(secrets.token_bytes(48))
challenge_digest = hashlib.sha256(code_verifier.encode()).digest()
code_challenge = b64url(challenge_digest)
state = b64url(secrets.token_bytes(16))

# Authorization request parameters (insertion order is the query-string order).
params = dict(
    response_type="code",
    client_id=CLIENT_ID,
    redirect_uri=REDIRECT_URI,
    scope=SCOPES,
    state=state,
    code_challenge=code_challenge,
    code_challenge_method="S256",
)
query = urllib.parse.urlencode(params)
auth_url = f"https://x.com/i/oauth2/authorize?{query}"
# Open this URL in a browser to authorize the app.
print(auth_url)

# Shared state written by the callback handler and polled by the serve loop.
result = {"code": None, "state": None, "error": None}


class H(BaseHTTPRequestHandler):
    """Minimal handler for the OAuth redirect URI.

    Records ``code``/``state`` (only when both are present) and any ``error``
    query parameter into the module-level ``result`` dict. Answers 204 for
    ``/favicon.ico`` and 200 "ok" for everything else.
    """

    def do_GET(self):
        parsed = urllib.parse.urlparse(self.path)
        if parsed.path == "/favicon.ico":
            # Browsers request this automatically; answer without touching result.
            self.send_response(204)
            self.end_headers()
            return

        qs = urllib.parse.parse_qs(parsed.query)

        def first(name):
            """Return the first value of query parameter *name*, or None."""
            return (qs.get(name) or [None])[0]

        error = first("error")
        if error:
            # Keep the provider's human-readable explanation (RFC 6749 §4.1.2.1)
            # so the eventual "Authorization error" message is actionable.
            description = first("error_description")
            result["error"] = f"{error}: {description}" if description else error

        code = first("code")
        st = first("state")
        if code and st:
            result["code"] = code
            result["state"] = st

        self.send_response(200)
        self.send_header("Content-Type", "text/plain; charset=utf-8")
        self.end_headers()
        self.wfile.write(b"ok")

    def log_message(self, *args):
        # Silence the default per-request logging.
        pass

# Listen on the redirect URI's host/port until the callback delivers a code
# (or an error), then check the returned state against the one we generated.
u = urllib.parse.urlparse(REDIRECT_URI)
host = u.hostname or "127.0.0.1"
port = u.port or 8088

# The context manager closes the listening socket once the callback has been
# handled, including when we bail out with SystemExit on an authorization error.
with HTTPServer((host, port), H) as httpd:
    while result["code"] is None:
        httpd.handle_request()  # serves exactly one request, then returns
        if result["error"]:
            raise SystemExit(f"Authorization error: {result['error']}")

if result["state"] != state:
    raise SystemExit("Authorization failed (state/code mismatch)")

# Exchange the authorization code (plus the PKCE verifier) for tokens.
token_url = "https://api.x.com/2/oauth2/token"
headers = {"Content-Type": "application/x-www-form-urlencoded"}

data = {
    "grant_type": "authorization_code",
    "client_id": CLIENT_ID,
    "code": result["code"],
    "redirect_uri": REDIRECT_URI,
    "code_verifier": code_verifier,
}

# Key point: let requests build the HTTP Basic Authorization header from the auth tuple.
r = requests.post(token_url, data=data, headers=headers, auth=(CLIENT_ID, CLIENT_SECRET), timeout=30)
print("TOKEN_HTTP:", r.status_code)

if r.status_code != 200:
    # Dump the raw body only on failure: a successful response carries live
    # access/refresh tokens that should not land in terminal scrollback or logs.
    print("TOKEN_BODY:", r.text)
    raise SystemExit("Token exchange failed (see TOKEN_BODY above)")

try:
    j = r.json()
except ValueError:
    raise SystemExit("Token endpoint returned a non-JSON body") from None

# Show which fields came back (names only, never the secret values).
print("TOKEN_FIELDS:", sorted(j))

refresh = j.get("refresh_token")
if not refresh:
    raise SystemExit("No refresh_token. Ensure offline.access scope is enabled.")

# Write the refresh token back into .env (replace an existing entry or append one).
env_path = ".env"


def _upsert_env_var(lines, key, value):
    """Return a copy of *lines* with every ``key=...`` assignment set to *value*.

    Lines shaped like ``[indent][export ]KEY=...`` are rewritten with their
    prefix preserved; if no such line exists, ``KEY=value`` is appended.
    """
    out = []
    found = False
    for ln in lines:
        body = ln.lstrip()
        if body.startswith("export "):
            body = body[len("export "):].lstrip()
        if body.startswith(key + "="):
            prefix = ln[: len(ln) - len(body)]
            out.append(f"{prefix}{key}={value}")
            found = True
        else:
            out.append(ln)
    if not found:
        out.append(f"{key}={value}")
    return out


try:
    with open(env_path, "r", encoding="utf-8") as f:
        lines = f.read().splitlines()
except FileNotFoundError:
    lines = []

out = _upsert_env_var(lines, "X_REFRESH_TOKEN", refresh)
# Context manager guarantees the file is flushed and closed before exit.
with open(env_path, "w", encoding="utf-8") as f:
    f.write("\n".join(out) + "\n")

print("ok")
