import sqlparse
import sqlglot
from sqlglot.expressions import ColumnDef


def odps(schema, table_name, columns, colmapping, hologres_connection):
    """Render MaxCompute DDL for an external table that reads a Hologres table over JDBC.

    Args:
        schema: Hologres schema, sent as ``currentSchema`` in the JDBC URL.
        table_name: Name used for both the MaxCompute table and the Hologres
            ``table=`` URL parameter.
        columns: Pre-rendered column list, inserted verbatim between the parentheses.
        colmapping: Pre-rendered value for ``odps.federation.jdbc.colmapping``.
        hologres_connection: Currently unused. The JDBC endpoint below is hard-coded.

    Returns:
        The DDL text. It starts with a newline and ends with ``");\\n"``.
    """
    # NOTE(review): host, port and database are hard-coded here even though a
    # ``hologres_connection`` argument is accepted — confirm this is intended.
    jdbc_url = (
        "jdbc:postgresql://hgprecn-cn-i7m2ssubq004-cn-hangzhou-internal.hologres.aliyuncs.com:80/sdk_statis"
        f"?ApplicationName=MaxCompute&currentSchema={schema}"
        f"&preferQueryMode=simple&useSSL=false&table={table_name}/"
    )
    # The two "--" lines are SQL comments emitted into the DDL. In English:
    # "use the classic-network IP; database plus schema plus table name" and
    # "format: MaxCompute column : Hologres column, ...". They are output text,
    # so they stay verbatim.
    ddl_lines = [
        "",
        f"CREATE EXTERNAL TABLE IF NOT EXISTS {table_name}",
        "(",
        f"{columns}",
        ")",
        "STORED BY 'com.aliyun.odps.jdbc.JdbcStorageHandler'",
        "-- ip设置成经典网络ip  库 加Schema 加表名",
        f"location '{jdbc_url}'",
        "TBLPROPERTIES (",
        "  'mcfed.mapreduce.jdbc.driver.class'='org.postgresql.Driver',",
        "  'odps.federation.jdbc.target.db.type'='holo',",
        '-- 格式为:MaxCompute字段1 : "Hologres字段1",MaxCompute字段2 : "Hologres字段2"',
        f"'odps.federation.jdbc.colmapping'='{colmapping}'",
        ");",
        "",
    ]
    return "\n".join(ddl_lines)


def extract_create_table(sql_script):
    """Return the CREATE TABLE statements found in an SQL script.

    The script is split into statements with sqlparse. For each statement,
    comments are stripped and keywords are lower-cased. Statements whose text
    begins with ``CREATE`` followed by whitespace and ``TABLE`` are kept. The
    match ignores case and allows any whitespace, including newlines, between
    the two words.

    Args:
        sql_script: Text containing one or more SQL statements.

    Returns:
        The matching statements joined with newlines, or ``""`` if none match.
    """
    import re

    create_table_statements = []

    for statement in sqlparse.parse(sql_script):
        # Strip comments and normalise keyword case. reindent stays off so the
        # rest of the statement text is kept as written.
        stripped = sqlparse.format(
            statement.value,
            strip_comments=True,
            reindent=False,
            keyword_case="lower",
        )

        # Skip empty statements, e.g. whitespace after the final semicolon.
        if not stripped.strip():
            continue

        # Adjust this pattern to collect other statement types. \s+ also
        # accepts "CREATE\nTABLE" or multiple spaces, which a literal
        # "CREATE TABLE" prefix test would miss.
        if re.match(r"create\s+table", stripped.strip(), re.IGNORECASE):
            create_table_statements.append(stripped)

    return "\n".join(create_table_statements)


def parse_create_table_sql(create_table_sql, hologres_connection):
    """Convert a Hologres (PostgreSQL) CREATE TABLE script into MaxCompute external-table DDL.

    Only the first CREATE TABLE statement in the script is converted. It is
    transpiled from the postgres dialect to hive and re-parsed. Each column
    definition is then mapped to a MaxCompute type, and the result is rendered
    by ``odps``.

    Args:
        create_table_sql: SQL script containing at least one CREATE TABLE statement.
        hologres_connection: Passed through unchanged to ``odps``.

    Returns:
        The DDL string produced by ``odps``.

    Raises:
        ValueError: If the script contains no CREATE TABLE statement, or a
            column definition has no data type.
    """
    result = extract_create_table(create_table_sql)
    if not result.strip():
        raise ValueError("no CREATE TABLE statement found in the SQL script")

    re_create_table_sql = sqlglot.transpile(result, read="postgres", write="hive")[0]
    parsed = sqlglot.parse_one(re_create_table_sql, read="hive")

    # parsed.this is the column Schema; its .this is the Table node.
    table = parsed.this.this
    # Use the parsed name parts rather than splitting the rendered SQL text.
    # Splitting breaks on unqualified names (the schema became the table name),
    # on catalog.schema.table names, and on quoted names (hive backticks leaked
    # into the output).
    table_name_str = table.name
    # Unqualified Hologres tables live in the default "public" schema.
    schema = table.db or "public"

    # MaxCompute spellings for type names that differ from sqlglot's enum names.
    type_overrides = {
        "TEXT": "STRING",
        # Every DECIMAL is widened to DECIMAL(20,8); source precision/scale is discarded.
        "DECIMAL": "DECIMAL(20,8)",
    }

    columns = []
    colmapping = []
    # find_all yields nodes in BFS order on every sqlglot version. Indexing
    # walk() results with [0] only works on older releases, where walk()
    # yields (node, parent, key) tuples.
    for column_def in parsed.find_all(ColumnDef):
        column_name = column_def.name
        kind = column_def.args.get("kind")
        if kind is None:
            raise ValueError(f"column {column_name!r} has no data type")
        # NOTE(review): only the type name is kept, so parameters such as
        # VARCHAR(n) lengths are dropped — confirm MaxCompute accepts the result.
        column_type = kind.this.name.upper()
        column_type = type_overrides.get(column_type, column_type)
        columns.append(column_name + " " + column_type)
        colmapping.append(column_name + ":" + column_name)

    # Columns go one per line in the DDL; the colmapping property is one comma-separated string.
    columns_str = ",\n".join(columns)
    colmapping_str = ",".join(colmapping)

    return odps(schema, table_name_str, columns_str, colmapping_str, hologres_connection)