# sqllineage/demo.py
# Snapshot dated 2025-02-05 14:18:02 +08:00 (170 lines, 4.0 KiB, Python).
#
# The hosting web viewer warned that this file contains Unicode characters
# that might be confused with other characters; if that is intentional, the
# warning can be safely ignored.

import sqlparse
import sqlglot
from sqlglot.expressions import ColumnDef
def extract_create_table(sql_script, prefixes=("CREATE TABLE",)):
    """Return the statements in *sql_script* that start with CREATE TABLE.

    The script is split into statements with ``sqlparse.parse``. Each
    statement is passed through ``sqlparse.format`` with comments stripped,
    no re-indentation, and keywords lower-cased, so the returned text is not
    byte-identical to the input. A statement is kept when its formatted text,
    upper-cased and with whitespace runs collapsed to single spaces, starts
    with one of *prefixes* followed by a non-identifier character (so
    ``CREATE TABLESPACE`` does not match ``CREATE TABLE``).

    Args:
        sql_script: SQL text that may contain several statements.
        prefixes: Statement prefix, or tuple of prefixes, to keep. Matching
            is case-insensitive and whitespace-insensitive. Defaults to
            ``("CREATE TABLE",)``; pass e.g. ``("CREATE TABLE", "CREATE VIEW")``
            to collect other statement types as well.

    Returns:
        The kept statements, in script order, joined with newlines; an empty
        string when nothing matches.
    """
    import re

    if isinstance(prefixes, str):
        prefixes = (prefixes,)
    # Normalize the prefixes exactly like the statements are normalized below.
    wanted = [" ".join(prefix.split()).upper() for prefix in prefixes]
    # An empty prefix would match every statement; ignore it.
    wanted = [prefix for prefix in wanted if prefix]
    if not wanted:
        return ""
    # (?!\w) stops "CREATE TABLE" from matching "CREATE TABLESPACE ...".
    pattern = re.compile(
        "(?:" + "|".join(re.escape(prefix) for prefix in wanted) + r")(?!\w)"
    )

    create_table_statements = []
    for statement in sqlparse.parse(sql_script):
        # Strip comments so a leading "-- ..." or "/* ... */" header does not
        # hide the statement keyword. Note that keyword_case="lower" does
        # change keyword case in the returned text.
        stripped = sqlparse.format(
            statement.value,
            strip_comments=True,
            reindent=False,
            keyword_case="lower",
        )
        # Collapse whitespace so "CREATE\n    TABLE" still matches.
        normalized = " ".join(stripped.split()).upper()
        # Skip statements that are empty once comments are removed.
        if not normalized:
            continue
        if pattern.match(normalized):
            create_table_statements.append(stripped)
    return "\n".join(create_table_statements)
# Sample input SQL script: a DDL dump wrapped in "BEGIN; ... END;". Besides
# the CREATE TABLE statement it contains a commented-out DROP TABLE, a
# COMMENT ON TABLE and an ALTER TABLE ... OWNER TO statement, none of which
# start with CREATE TABLE.
# NOTE(review): the WITH (...) options (orientation, table_group,
# segment_key, ...) look like Alibaba Cloud Hologres table properties —
# presumably exported from there; confirm.
sql_script = """
BEGIN;
/*
DROP TABLE ods.track_log_002;
*/
-- Type: TABLE ; Name: track_log_002; Owner: sdk_statis_developer
CREATE TABLE ods.track_log_002 (
appid bigint NOT NULL,
app_ver text,
sdk_ver text,
channel text,
country text,
province text,
city text,
isp text,
ip text,
device_width integer,
device_height integer,
device_id text NOT NULL,
device_lang text,
device_model text,
device_brand text,
device_os text,
device_type text,
event_name text NOT NULL,
event_type text,
event_time bigint NOT NULL,
net_type text,
user_id text,
order_id text,
amount bigint,
platform text,
status integer,
servid text,
server_name text,
role_id text,
role_name text,
role_level text,
job_id text,
job_name text,
var1 text,
var2 text,
var3 text,
var4 text,
var5 text,
var6 text,
var7 text,
var8 text,
var9 text,
var10 text,
var11 text,
var12 text,
var13 text,
var14 text,
var15 text,
var16 text,
var17 text,
var18 text,
var19 text,
var20 text,
var21 text,
var22 text,
var23 text,
var24 text,
var25 text,
var26 text,
var27 text,
var28 text,
var29 text,
var30 text,
ds text NOT NULL,
prodid text,
prod_name text,
sub_servid text,
sub_server_name text
)
PARTITION BY LIST (ds)with (
orientation = 'column',
storage_format = 'orc',
auto_partitioning_enable = 'true',
auto_partitioning_num_hot = '90',
auto_partitioning_num_precreate = '2',
auto_partitioning_num_retention = '191',
auto_partitioning_schd_start_time = '1970-01-01 00:00:00',
auto_partitioning_time_format = '',
auto_partitioning_time_unit = 'day',
auto_partitioning_time_zone = 'PRC',
bitmap_columns = 'appid,event_name,ds,role_id,device_id,servid,user_id,country,channel,province,status,city,device_width,var4,var3,var2,var1,amount,device_height,var12,var13,var14,var15,var10,var11,var9,var8,var7,var6,var5,event_time',
clustering_key = 'appid:asc',
dictionary_encoding_columns = '',
segment_key = 'event_time',
table_group = 'sdk_statis_tg_s80',
table_storage_mode = 'hot',
time_to_live_in_seconds = '16416000'
);
COMMENT ON TABLE ods.track_log_002 IS NULL;
ALTER TABLE ods.track_log_002 OWNER TO sdk_statis_developer;
END;
"""
# Run the extraction on the sample script above.
result = extract_create_table(sql_script)
if not result:
    # Without this guard sqlglot.parse_one("") fails with an opaque ParseError.
    raise ValueError("no CREATE TABLE statement found in sql_script")
# Round-trip through sqlglot: parse as PostgreSQL, regenerate as Hive, then
# re-parse the Hive text. Only the first generated statement is used when
# several CREATE TABLE statements were extracted.
re_create_table_sql = sqlglot.transpile(result, read="postgres", write="hive")[0]
parsed = sqlglot.parse_one(re_create_table_sql, read='hive')
# Table name: with a column list, Create.this is a Schema wrapping the Table;
# without one (e.g. CREATE TABLE ... AS SELECT) it is the Table itself.
target = parsed.this
table_name = target.this if isinstance(target, sqlglot.exp.Schema) else target
columns = []
# find_all() yields ColumnDef nodes directly on every sqlglot version, whereas
# walk() yielded (node, parent, key) tuples only in older releases, so
# indexing walk() items with [0] breaks on current sqlglot.
for expression in parsed.find_all(ColumnDef):
    # Column name (identifier text, without quotes).
    column_name = expression.name
    # Data type name from the DataType.Type enum; empty if no type is declared.
    kind = expression.args.get('kind')
    column_type = kind.this.name.upper() if kind is not None else ''
    # Report TEXT with the Hive spelling STRING. Presumably Hive's STRING is
    # parsed back as DataType.Type.TEXT — verify against the installed sqlglot.
    if column_type == 'TEXT':
        column_type = 'STRING'
    columns.append({'name': column_name, 'type': column_type})
# Print the table name, then one line per column.
print(f"表名称: {table_name}")
for column in columns:
    print(f"字段名称: {column['name']}, 字段类型: {column['type']}")