from flask import Flask, render_template, request, redirect, url_for, flash, session, abort, jsonify
from flask_sqlalchemy import SQLAlchemy
from datetime import datetime
import os
import random
from dotenv import load_dotenv

# Load variables from a local .env file (python-dotenv) before reading them.
load_dotenv()

app = Flask(__name__)
# NOTE(review): the SECRET_KEY and API_KEY fallbacks are placeholder values
# that are visible in the source; set SECRET_KEY and OPENCLAW_API_KEY in the
# environment for any real deployment.
app.config.update(
    SECRET_KEY=os.environ.get('SECRET_KEY', 'dev_key_for_weedthought'),
    SQLALCHEMY_DATABASE_URI='sqlite:///weedthought.db',
    SQLALCHEMY_TRACK_MODIFICATIONS=False,
    API_KEY=os.environ.get('OPENCLAW_API_KEY', 'change_this_to_a_secure_key'),
)

# Database handle bound to the app; the models below derive from db.Model.
db = SQLAlchemy(app)

class Story(db.Model):
    """A user-submitted story."""

    # Integer primary key.
    id = db.Column(db.Integer, primary_key=True)
    # Story text; required, column declared with length 300.
    content = db.Column(db.String(300), nullable=False)
    # Defaults to datetime.utcnow() (naive UTC) when no value is supplied.
    created_at = db.Column(db.DateTime, default=datetime.utcnow)

    def __repr__(self):
        """Return a short debug representation, e.g. ``<Story 7>``."""
        return '<Story {}>'.format(self.id)

class BannedIP(db.Model):
    """A banned client IP address; check_banned() rejects matching requests."""

    # Integer primary key.
    id = db.Column(db.Integer, primary_key=True)
    # The banned address; required and unique, so each IP has at most one row.
    ip_address = db.Column(db.String(50), unique=True, nullable=False)
    # Optional free-text explanation (penalize_ip stores the triggering reason).
    reason = db.Column(db.String(200))
    # Defaults to datetime.utcnow() (naive UTC) when no value is supplied.
    banned_at = db.Column(db.DateTime, default=datetime.utcnow)

# Create the tables for the models declared above at import time.
# db.create_all() needs an active application context, hence the `with`.
with app.app_context():
    db.create_all()

# In-memory spam-strike counters, keyed by client IP (updated by penalize_ip).
# This is a plain module-level dict: it is not persisted, so counts are lost
# on restart and are not shared between separate processes.
IP_STRIKES = {}

# Trusted IPs that are never penalized and skip the ban lookup entirely.
WHITELISTED_IPS = {'127.0.0.1'}  # Add your public IP here if deploying

@app.before_request
def check_banned():
    """Abort with 403 when the requesting IP has a ban record.

    Registered as a before-request hook. Whitelisted addresses return
    immediately without touching the database.
    """
    client_ip = request.remote_addr
    if client_ip in WHITELISTED_IPS:
        return None

    ban_record = BannedIP.query.filter_by(ip_address=client_ip).first()
    if ban_record is not None:
        abort(403, description="Your IP has been banned due to suspicious activity.")
    return None

def penalize_ip(ip, reason, weight=1, threshold=3):
    """Add spam strikes to an IP and ban it once the threshold is reached.

    Args:
        ip: Client address to penalize.
        reason: Text stored on the ban record if this call creates one.
        weight: Number of strikes to add (default 1).
        threshold: Strike total at which the IP gets banned (default 3).

    Returns:
        True when the IP's strike total is at or above ``threshold`` after
        this call (a ban record exists or was just created); False otherwise.
        Whitelisted or missing addresses are never penalized.
    """
    # A missing address cannot be stored (BannedIP.ip_address is non-nullable),
    # so counting strikes for it would only end in a failed insert.
    if not ip or ip in WHITELISTED_IPS:
        return False

    strikes = IP_STRIKES.get(ip, 0) + weight
    IP_STRIKES[ip] = strikes
    if strikes < threshold:
        return False

    # Only insert when no row exists: ip_address is unique.
    if BannedIP.query.filter_by(ip_address=ip).first() is None:
        db.session.add(BannedIP(ip_address=ip, reason=reason))
        db.session.commit()

    # The IP is banned either way, so its in-memory counter is no longer needed.
    IP_STRIKES.pop(ip, None)
    return True

@app.route('/')
def index():
    """Render the home page: all stories, newest first, plus a fresh captcha.

    Stores the expected captcha answer and the form load time in the session
    so that post_story() can validate the submission.
    """
    newest_first = Story.query.order_by(Story.created_at.desc()).all()

    # Anti-spam: simple addition question with two operands in 1..10.
    left, right = (random.randint(1, 10) for _ in range(2))
    session['captcha_answer'] = left + right

    # Anti-spam: remember when the form was served.
    session['form_start_time'] = datetime.utcnow().timestamp()

    return render_template(
        'index.html',
        stories=newest_first,
        captcha_question=f"What is {left} + {right}?",
    )

@app.route('/post', methods=['POST'])
def post_story():
    """Validate a story submitted from the home-page form and store it.

    Checks run in this order: per-session rate limit, minimum form fill
    time, honeypot field, math captcha, then content length (140-300
    characters). Failed anti-spam checks add strikes to the client IP via
    penalize_ip(). Every outcome redirects back to the index page, with a
    flashed message where one applies.
    """
    # Same clock expression as index(), so the two timestamps are comparable.
    current_time = datetime.utcnow().timestamp()
    client_ip = request.remote_addr

    # Anti-spam: Rate limiting (1 post per 60 seconds per session)
    last_post_time = session.get('last_post_time')
    if last_post_time and (current_time - last_post_time) < 60:
        penalize_ip(client_ip, "Rate limit violation")
        flash('Please wait a minute before posting again.')
        return redirect(url_for('index'))

    # Anti-spam: Time-based check (Must take at least 3 seconds)
    start_time = session.get('form_start_time')
    if not start_time or (current_time - start_time) < 3:
        penalize_ip(client_ip, "Posting too fast")
        flash('You are posting too fast. Please take your time.')
        return redirect(url_for('index'))

    # Anti-spam: Honeypot check (requires hidden input named 'website_check' in HTML)
    if request.form.get('website_check'):
        penalize_ip(client_ip, "Honeypot triggered", weight=3)  # Instant ban
        return redirect(url_for('index'))

    # Anti-spam: math captcha.
    # str.isdigit() accepts characters such as '²' that int() rejects, and
    # int() also raises on extremely long digit strings; both used to surface
    # as an unhandled ValueError. isdecimal() keeps the "digits only" rule
    # (no sign, no whitespace) and the except treats anything else as wrong.
    captcha_input = request.form.get('captcha', '')
    try:
        captcha_value = int(captcha_input) if captcha_input.isdecimal() else None
    except ValueError:
        captcha_value = None
    if captcha_value is None or captcha_value != session.get('captcha_answer'):
        penalize_ip(client_ip, "Captcha failed")
        flash('Incorrect math answer. Please try again.')
        return redirect(url_for('index'))

    content = request.form.get('content')
    if not content:
        flash('Story cannot be empty!')
        return redirect(url_for('index'))

    if len(content) < 140:
        flash(f'Story is too short ({len(content)} chars). It must be at least 140 characters long.')
        return redirect(url_for('index'))

    if len(content) > 300:
        flash(f'Story is too long ({len(content)} chars). It must be under 300 characters.')
        return redirect(url_for('index'))

    db.session.add(Story(content=content))
    db.session.commit()
    # Start the rate-limit window only after a successful post.
    session['last_post_time'] = current_time
    flash('Story posted successfully!')
    return redirect(url_for('index'))

# --- OpenClaw API Endpoints ---

@app.route('/api/stories', methods=['GET'])
def api_get_stories():
    """API endpoint for OpenClaw to read recent stories.

    The ``X-API-Key`` request header must equal the configured ``API_KEY``.

    Returns:
        A JSON list of up to 10 stories, newest first, each with ``id``,
        ``content`` and ``created_at`` (ISO 8601 string, or null when the
        row has no timestamp); or a 401 JSON error for a bad/missing key.
    """
    import hmac  # local import keeps this endpoint self-contained

    supplied_key = request.headers.get('X-API-Key')
    expected_key = app.config['API_KEY']
    # Constant-time comparison so response timing does not leak how much of
    # the key matched. Bytes are compared because compare_digest() rejects
    # non-ASCII str arguments.
    if supplied_key is None or not hmac.compare_digest(
        supplied_key.encode('utf-8'), expected_key.encode('utf-8')
    ):
        return jsonify({'error': 'Unauthorized'}), 401

    stories = Story.query.order_by(Story.created_at.desc()).limit(10).all()
    return jsonify([{
        'id': s.id,
        'content': s.content,
        # created_at is nullable at the schema level; do not crash on NULL.
        'created_at': s.created_at.isoformat() if s.created_at else None,
    } for s in stories])

@app.route('/api/post', methods=['POST'])
def api_post_story():
    """API endpoint for OpenClaw to post stories.

    The ``X-API-Key`` request header must equal the configured ``API_KEY``.
    The body must be a JSON object with a string ``content`` of 140-300
    characters.

    Returns:
        201 with the new story id on success, 400 with a JSON error for an
        invalid body or content, 401 with a JSON error for a bad/missing key.
    """
    import hmac  # local import keeps this endpoint self-contained

    supplied_key = request.headers.get('X-API-Key')
    expected_key = app.config['API_KEY']
    # Constant-time comparison so response timing does not leak how much of
    # the key matched. Bytes are compared because compare_digest() rejects
    # non-ASCII str arguments.
    if supplied_key is None or not hmac.compare_digest(
        supplied_key.encode('utf-8'), expected_key.encode('utf-8')
    ):
        return jsonify({'error': 'Unauthorized'}), 401

    # silent=True yields None for a missing, non-JSON or malformed body
    # instead of raising, so every client error gets a JSON response.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    content = payload.get('content')
    # Non-string content (number, list, null) is rejected before len() is used.
    if not isinstance(content, str) or not 140 <= len(content) <= 300:
        return jsonify({'error': 'Content must be between 140 and 300 characters'}), 400

    new_story = Story(content=content)
    db.session.add(new_story)
    db.session.commit()
    return jsonify({'message': 'Story posted successfully', 'id': new_story.id}), 201

@app.route('/story/<int:story_id>')
def view_story(story_id):
    """Render the page for a single story, or respond 404 if the id is unknown."""
    story = Story.query.get(story_id)
    if story is None:
        abort(404)
    return render_template('story.html', story=story)

from flask import make_response

@app.route('/robots.txt')
def robots():
    """Serve a plain-text robots.txt that allows all paths and links the sitemap."""
    sitemap_url = url_for('sitemap', _external=True)
    response = make_response(f"User-agent: *\nAllow: /\nSitemap: {sitemap_url}")
    response.headers["Content-Type"] = "text/plain"
    return response

@app.route('/sitemap.xml')
def sitemap():
    """Serve an XML sitemap listing the home page and every story page.

    Returns:
        A response with Content-Type ``application/xml`` containing one
        ``<url>`` entry for the index and one per story (newest first).
    """
    from xml.sax.saxutils import escape

    # External URLs are built from the request's host, so they are escaped
    # before being placed in XML (the sitemap protocol requires entity
    # escaping of URL values anyway).
    home_url = escape(url_for('index', _external=True))
    stories = Story.query.order_by(Story.created_at.desc()).all()

    parts = [
        '<?xml version="1.0" encoding="UTF-8"?>',
        '<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">',
        # Homepage
        '<url>',
        f'<loc>{home_url}</loc>',
        '<changefreq>daily</changefreq>',
        '<priority>1.0</priority>',
        '</url>',
    ]

    # Stories
    for story in stories:
        story_url = escape(url_for('view_story', story_id=story.id, _external=True))
        parts.append('<url>')
        parts.append(f'<loc>{story_url}</loc>')
        # created_at is nullable at the schema level; <lastmod> is optional,
        # so it is simply omitted for rows without a timestamp.
        if story.created_at is not None:
            parts.append(f"<lastmod>{story.created_at.strftime('%Y-%m-%d')}</lastmod>")
        parts.append('<changefreq>never</changefreq>')
        parts.append('<priority>0.8</priority>')
        parts.append('</url>')

    parts.append('</urlset>')

    response = make_response("".join(parts))
    response.headers["Content-Type"] = "application/xml"
    return response

if __name__ == '__main__':
    # SECURITY: debug mode enables the Werkzeug interactive debugger, which
    # lets anyone who can reach the server execute code. Because this server
    # binds to all interfaces (0.0.0.0), debug is opt-in: set FLASK_DEBUG=1
    # (or true/yes/on) for local development.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')
    app.run(debug=debug_enabled, host='0.0.0.0', port=5001)
