#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CSV到MD的实时同步脚本
监控CSV文件变化，自动更新publications.md中的tbody内容
"""

import argparse
import csv
import html
import os
import re
import time
from pathlib import Path
from typing import List, Tuple, Optional


# id attribute of the HTML <table> whose <tbody> content this script regenerates.
TABLE_ID = 'publications-table'
# Inline style attribute text inserted into every generated <td> tag.
ROW_STYLE = 'style="padding:10px;border-right:none;border-bottom:1px solid #eee;"'


def normalize_doi(raw: str) -> str:
    """Turn the text of a DOI column into a link target.

    Handling, in order:

    * empty or ``None`` input gives ``''``;
    * a value already starting with ``http://`` or ``https://`` is returned
      stripped but otherwise unchanged;
    * a scheme-less resolver address (``doi.org/...`` or ``dx.doi.org/...``)
      gets ``https://`` prepended;
    * a leading ``doi`` label (``doi:``, ``DOI ``, ...) is removed, and a
      value that then starts with ``10.`` becomes ``https://doi.org/<value>``;
    * anything else is returned stripped, minus a recognised label.

    Args:
        raw: Raw cell text.

    Returns:
        The normalised URL, or the cleaned-up text when no DOI is recognised.
    """
    if not raw:
        return ''
    raw = raw.strip()
    if raw.lower().startswith(('http://', 'https://')):
        return raw
    # A resolver address without a scheme has to be handled before the label
    # stripping below, which would otherwise eat its leading "doi" and leave
    # ".org/..." behind.
    if re.match(r'(?i:(?:dx\.)?doi\.org)/', raw):
        return f'https://{raw}'
    # Remove a "doi:" / "DOI " style label. The lookahead keeps ordinary words
    # that merely begin with those three letters (e.g. "doing") intact.
    raw = re.sub(r'^(?i:doi(?![a-z.]))\s*[:/\s]*', '', raw)
    if raw.startswith('10.'):
        return f'https://doi.org/{raw}'
    return raw


def parse_interpretation(interpretation: str) -> Tuple[str, str]:
    """
    Split an "article interpretation" cell into intro text and link.

    Expected cell format: ``intro text | link URL`` (the separator is a pipe
    surrounded by single spaces; only the first occurrence splits the cell).

    Returns:
        ``(intro, link)``, both stripped. Either part may be ``''``. A cell
        without the separator is treated as a link when it starts with
        ``http`` and as intro text otherwise.
    """
    if not (interpretation and interpretation.strip()):
        return '', ''

    head, separator, tail = interpretation.partition(' | ')
    if separator:
        return head.strip(), tail.strip()

    # No separator present: the whole cell is either a bare link or plain text.
    stripped = interpretation.strip()
    return ('', stripped) if stripped.startswith('http') else (stripped, '')


def build_row(year: str, title: str, journal: str, impact: str, sample: str, product: str, doi: str, interpretation: str = '') -> str:
    """Render one publication as an HTML ``<tr>`` containing eight ``<td>`` cells.

    The first six values are HTML-escaped and emitted as plain text cells.
    The DOI cell shows the stripped DOI text; it is wrapped in a link only
    when ``normalize_doi`` yields an ``http(s)`` URL, so placeholder values
    such as ``N/A`` stay plain text instead of becoming a bogus link.
    The last cell holds the intro and/or link produced by
    ``parse_interpretation``.

    Args:
        year: Publication year.
        title: Article title.
        journal: Journal name.
        impact: Impact factor.
        sample: Sample type.
        product: Product type.
        doi: DOI text or URL.
        interpretation: Optional ``intro | link`` text.

    Returns:
        The ``<tr>`` markup (indented by eight spaces, no trailing newline).
    """
    year_h = html.escape(year)
    title_h = html.escape(title)
    journal_h = html.escape(journal)
    impact_h = html.escape(impact)
    sample_h = html.escape(sample)
    product_h = html.escape(product)

    # DOI cell: the visible label is always the text as typed in the CSV.
    doi_label = doi.strip() if doi else ''
    doi_h = html.escape(doi_label)
    href = normalize_doi(doi)
    # Only a real web address is turned into an anchor; anything else would
    # be a broken relative link (or an unwanted scheme such as javascript:).
    if href.lower().startswith(('http://', 'https://')):
        doi_cell = f'<a href="{html.escape(href)}" target="_blank">{doi_h}</a>'
    else:
        doi_cell = doi_h

    # Interpretation cell.
    intro, link = parse_interpretation(interpretation)

    if intro or link:
        interpretation_cell = '          <td ' + ROW_STYLE + '>\n'
        if intro:
            # NOTE(security): the intro is inserted as raw HTML on purpose so
            # the CSV can carry rich text; the CSV must therefore be trusted.
            interpretation_cell += f'            <div data-intro-html>{intro}</div>\n'
        if link:
            interpretation_cell += f'            <a href="{html.escape(link)}" target="_blank">查看解读</a>\n'
        interpretation_cell += '          </td>'
    else:
        interpretation_cell = f'          <td {ROW_STYLE}></td>'

    return (
        '        <tr>\n'
        f'          <td {ROW_STYLE}>{year_h}</td>\n'
        f'          <td {ROW_STYLE}>{title_h}</td>\n'
        f'          <td {ROW_STYLE}>{journal_h}</td>\n'
        f'          <td {ROW_STYLE}>{impact_h}</td>\n'
        f'          <td {ROW_STYLE}>{sample_h}</td>\n'
        f'          <td {ROW_STYLE}>{product_h}</td>\n'
        f'          <td {ROW_STYLE}>{doi_cell}</td>\n'
        f'{interpretation_cell}\n'
        '        </tr>'
    )


def read_csv_rows(csv_path: str) -> List[List[str]]:
    """Read the data rows of a UTF-8 CSV file.

    The first non-blank row is treated as the header and dropped. Rows that
    are blank or contain only empty/whitespace cells are skipped so they do
    not turn into empty table rows. Every remaining row is right-padded with
    ``''`` to at least 8 columns (year, title, journal, impact factor,
    sample type, product type, DOI, interpretation); longer rows are kept
    as they are.

    Args:
        csv_path: Path of the CSV file.

    Returns:
        The data rows, or ``[]`` when the file has no data rows or cannot be
        read (the error is printed, not raised).
    """
    rows = []
    try:
        # newline='' hands line-ending handling to the csv module, which is
        # required for line breaks inside quoted fields to survive intact.
        with open(csv_path, 'r', encoding='utf-8', newline='') as f:
            for row in csv.reader(f):
                if not any(cell.strip() for cell in row):
                    continue
                # A negative repeat count yields [], so long rows are untouched.
                row.extend([''] * (8 - len(row)))
                rows.append(row)
    except Exception as e:
        # Best-effort by design: report the problem and let the caller see
        # "no rows" instead of an exception.
        print(f"读取CSV文件失败: {e}")
        return []

    # Drop the header row.
    return rows[1:]


def extract_tbody_blocks(md_text: str) -> Tuple[int, int]:
    """Locate the inner span of the target table's ``<tbody>``.

    The table is found by the literal text ``<table id="<TABLE_ID>"``. The
    opening ``<tbody>`` must appear before that table's first ``</table>``
    (or before the end of the text when no closing tag exists), so a table
    without its own ``<tbody>`` can never cause a later table's body to be
    selected. The closing tag is the first ``</tbody>`` after the opening one.

    Args:
        md_text: Full text of the Markdown/HTML document.

    Returns:
        ``(start, end)`` where ``start`` is the index just past ``<tbody>``
        and ``end`` is the index of ``</tbody>``; ``md_text[start:end]`` is
        the current body content.

    Raises:
        ValueError: If the table or either tbody tag cannot be found.
    """
    table_idx = md_text.find(f'<table id="{TABLE_ID}"')
    if table_idx == -1:
        raise ValueError(f'未找到 id="{TABLE_ID}" 的表格')

    # Upper bound for the opening tag: the end of this table.
    table_end = md_text.find('</table>', table_idx)
    if table_end == -1:
        table_end = len(md_text)

    tbody_start = md_text.find('<tbody>', table_idx, table_end)
    if tbody_start == -1:
        raise ValueError('未找到 <tbody> 或 </tbody> 标签')

    tbody_end = md_text.find('</tbody>', tbody_start)
    if tbody_end == -1:
        raise ValueError('未找到 <tbody> 或 </tbody> 标签')

    return tbody_start + len('<tbody>'), tbody_end


def update_md_from_csv(csv_path: str, md_path: str, encoding: str = 'utf-8') -> bool:
    """Regenerate the publications table body in the MD file from the CSV.

    Reads the CSV data rows, renders each as an HTML table row, and replaces
    everything between ``<tbody>`` and ``</tbody>`` of the target table with
    the rendered rows. All failures are reported on stdout rather than raised.

    Args:
        csv_path: Path of the source CSV file.
        md_path: Path of the Markdown file to rewrite in place.
        encoding: Encoding used to read and write the Markdown file.

    Returns:
        ``True`` when the file was rewritten, ``False`` when the CSV yielded
        no rows or any step failed.
    """
    try:
        # Load the CSV data.
        csv_rows = read_csv_rows(csv_path)
        if not csv_rows:
            print("CSV文件为空或读取失败")
            return False

        # Render one <tr> per CSV row; only the first 8 columns are used.
        html_rows = [build_row(*row[:8]) for row in csv_rows if len(row) >= 8]

        # Load the current document.
        with open(md_path, 'r', encoding=encoding) as f:
            md_text = f.read()

        # Swap the old tbody content for the freshly rendered rows.
        start, end = extract_tbody_blocks(md_text)
        new_tbody_inner = '\n' + '\n'.join(html_rows) + '\n      '
        new_md = md_text[:start] + new_tbody_inner + md_text[end:]

        # Encode BEFORE opening the file for writing: opening truncates it,
        # so an encoding error raised afterwards would wipe the document.
        # Writing bytes also keeps "\n" untranslated on every platform.
        payload = new_md.encode(encoding)
        with open(md_path, 'wb') as f:
            f.write(payload)

        print(f"已更新 {len(html_rows)} 条文章记录到 {md_path}")
        return True

    except Exception as e:
        # Top-level boundary for one sync attempt: report and signal failure.
        print(f"更新失败: {e}")
        return False


def watch_csv_file(csv_path: str, md_path: str, encoding: str = 'utf-8', interval: float = 1.0):
    """Poll the CSV file and re-sync the MD file whenever the CSV changes.

    Returns immediately (after printing a message) when either file does not
    exist. Otherwise loops until interrupted with Ctrl+C, comparing the CSV
    modification time against the last successfully synced one. A failed
    sync leaves the remembered time untouched, so it is retried on the next
    poll.

    Args:
        csv_path: Path of the CSV file to watch.
        md_path: Path of the Markdown file to update.
        encoding: Encoding passed through to ``update_md_from_csv``.
        interval: Seconds to sleep between polls.
    """
    csv_file = Path(csv_path)
    md_file = Path(md_path)

    if not csv_file.exists():
        print(f"CSV文件不存在: {csv_path}")
        return

    if not md_file.exists():
        print(f"MD文件不存在: {md_path}")
        return

    print(f"开始监控 {csv_path}")
    print(f"目标文件 {md_path}")
    print(f"检查间隔 {interval}秒")
    print("按 Ctrl+C 停止监控")

    last_mtime = csv_file.stat().st_mtime
    idle_checks = 0  # consecutive polls without a detected change

    try:
        while True:
            try:
                current_mtime = csv_file.stat().st_mtime
                # "!=" rather than ">": replacing the file with a copy that
                # carries an older timestamp is a change as well.
                if current_mtime != last_mtime:
                    idle_checks = 0
                    print(f"\n检测到CSV文件变化 ({time.strftime('%H:%M:%S')})")
                    print(f"文件修改时间: {last_mtime} -> {current_mtime}")
                    if update_md_from_csv(csv_path, md_path, encoding):
                        last_mtime = current_mtime
                        print("更新成功，继续监控...")
                    else:
                        print("更新失败，继续监控...")
                else:
                    # Heartbeat on every 10th idle poll, independent of the
                    # wall clock and of the chosen interval.
                    idle_checks += 1
                    if idle_checks % 10 == 0:
                        print(f"监控中... 当前时间: {time.strftime('%H:%M:%S')}")
            except Exception as e:
                # e.g. the CSV vanished for a moment; keep watching.
                print(f"监控过程中出现错误: {e}")

            time.sleep(interval)

    except KeyboardInterrupt:
        print("\n监控已停止")


def main():
    """Command-line entry point.

    Parses ``--csv``, ``--md``, ``--encoding``, ``--watch`` and
    ``--interval``, then either watches the CSV file or performs a single
    sync.

    Raises:
        SystemExit: With status 1 when a single sync fails, so shells and
            scripts can detect the failure; with status 2 (via argparse) on
            invalid arguments, including a non-positive ``--interval`` in
            watch mode.
    """
    parser = argparse.ArgumentParser(description='CSV到MD的实时同步工具')
    parser.add_argument('--csv', required=True, help='CSV文件路径')
    parser.add_argument('--md', required=True, help='MD文件路径')
    parser.add_argument('--encoding', default='utf-8', help='文件编码，默认utf-8')
    parser.add_argument('--watch', action='store_true', help='监控模式，实时同步')
    parser.add_argument('--interval', type=float, default=1.0, help='监控间隔（秒），默认1.0')

    args = parser.parse_args()

    if args.watch:
        # "not > 0" also rejects NaN; a zero interval would busy-loop and a
        # negative one makes time.sleep() raise.
        if not args.interval > 0:
            parser.error('--interval 必须大于 0')
        watch_csv_file(args.csv, args.md, args.encoding, args.interval)
        return

    # One-shot sync.
    if update_md_from_csv(args.csv, args.md, args.encoding):
        print("同步完成")
    else:
        print("同步失败")
        raise SystemExit(1)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()
