流程

先写一个简单的基于 AST 的危险函数查找器,然后在此基础上打补丁:提取出危险函数的参数,判断该参数是否由拼接而来——既可以用正则判断,也可以直接通过 AST 判断。这里选择直接通过 AST 判断:当参数全为常量时不存在注入;当参数中出现了变量时则存在注入风险。

然后再打补丁来判断这个参数是否可控。目前还没有找到通过 AST 向上回溯参数来源的方法,个人目前的想法是:将所有可控参数都记录下来,当发现 $_GET、$_POST 这类输入时,将其记录为一条赋值链,结构采用双向链表或其他可以向上查询的结构。这样,当在危险函数中发现参数时,只要查找它是否在链中,就可以确定是否存在 SQL 注入。

AST

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from phply import phplex
from phply.phpparse import make_parser
from phply import phpast as php

class PHPASTAuditor:
    """First-pass PHP auditor built on phply ASTs.

    Flags every call to a function listed in ``danger_functions`` and, for
    the SQL query functions, additionally flags calls whose arguments contain
    a variable (i.e. are not built purely from constants).  No taint tracking
    is done: any variable counts.
    """

    # Query functions whose arguments are inspected for embedded variables.
    SQL_QUERY_FUNCTIONS = ('mysql_query', 'mysqli_query', 'pg_query')

    def __init__(self):
        self.lexer = phplex.lexer.clone()
        self.parser = make_parser()
        # Calls that are always reported, regardless of their arguments.
        self.danger_functions = [
            'mysql_query', 'mysqli_query', 'pg_query',
            'eval', 'exec', 'shell_exec', 'system', 'passthru'
        ]
        # NOTE(review): not referenced by any method of this class.
        self.sql_keywords = ['select', 'insert', 'update', 'delete', 'replace', 'drop']

    def parse_php_code(self, code):
        """Parse PHP *code* and return phply's list of top-level AST nodes.

        A fresh lexer clone is taken for every parse: the lexer's line
        counter is not reset between inputs, so reusing one lexer across
        files would make every file after the first report shifted lines.
        """
        self.lexer = phplex.lexer.clone()
        return self.parser.parse(code, lexer=self.lexer)

    def get_func_name(self, node):
        """Return the called function's name for a FunctionCall node, else None."""
        if isinstance(node, php.FunctionCall):
            if hasattr(node.name, 'name'):
                return node.name.name
            if isinstance(node.name, str):
                return node.name
        return None

    def is_danger_function(self, node):
        """Return True if *node* is a call to a function in ``danger_functions``."""
        name = self.get_func_name(node)
        return name in self.danger_functions if name else False

    def contains_user_variable(self, node):
        """Return True if *node* contains a variable rather than only constants.

        Recognised shapes:
        - ``Parameter`` call-argument wrappers (the wrapped node is checked);
        - a plain ``Variable``;
        - an ``ArrayOffset`` such as ``$_GET['id']`` (true when the array
          base contains a variable);
        - ``.`` string concatenations (either side may contain a variable).
        Anything else, including literal values, counts as constant.
        Nodes are matched by class name, so any object with the same
        class name and attributes is accepted.
        """
        kind = node.__class__.__name__

        # Call arguments arrive wrapped in a Parameter node.
        if kind == 'Parameter' and hasattr(node, 'node'):
            return self.contains_user_variable(node.node)

        if kind == 'Variable':
            return True

        # $_GET['id'] is an array access on a variable; without this branch
        # direct superglobal use was never flagged.
        if kind == 'ArrayOffset' and hasattr(node, 'node'):
            return self.contains_user_variable(node.node)

        # String concatenation: a variable on either side taints the result.
        if getattr(node, 'op', '') == '.' and hasattr(node, 'left') and hasattr(node, 'right'):
            return (self.contains_user_variable(node.left)
                    or self.contains_user_variable(node.right))

        return False

    def audit_node(self, node):
        """Recursively audit *node* and its descendants; return finding strings.

        At most one "Dangerous function call" and one "Potential SQL usage"
        finding is produced per call site.
        """
        findings = []
        func_name = self.get_func_name(node)
        line = getattr(node, 'lineno', 'unknown')

        if self.is_danger_function(node):
            findings.append(f"Dangerous function call: {func_name} at line {line}")

        # Report a SQL call once if any of its arguments embeds a variable
        # (previously one identical finding was emitted per such argument).
        if func_name in self.SQL_QUERY_FUNCTIONS and any(
                self.contains_user_variable(arg) for arg in getattr(node, 'params', [])):
            findings.append(f"Potential SQL usage in {func_name} at line {line}")

        # Recurse into child nodes, whether held directly or inside lists.
        for value in getattr(node, '__dict__', {}).values():
            children = value if isinstance(value, list) else [value]
            for child in children:
                if isinstance(child, php.Node):
                    findings.extend(self.audit_node(child))

        return findings

    def audit_php_code(self, code):
        """Parse and audit PHP *code*; return finding strings.

        Any exception during parsing or auditing is reported as a single
        ``"Error parsing code: ..."`` entry instead of being raised.
        """
        try:
            # An empty input may yield no node list at all.
            ast_nodes = self.parse_php_code(code) or []
            all_findings = []
            for node in ast_nodes:
                all_findings.extend(self.audit_node(node))
            return all_findings
        except Exception as e:
            return [f"Error parsing code: {e}"]


import os

def get_php_files(folder_path):
    """Return the paths of all ``.php`` files under *folder_path*, recursively.

    Each path is ``os.path.join(directory, filename)`` and the list follows
    ``os.walk`` traversal order.
    """
    matches = []
    for dirpath, _subdirs, filenames in os.walk(folder_path):
        matches.extend(
            os.path.join(dirpath, fname) for fname in filenames if fname.endswith(".php")
        )
    return matches

def audit_folder(folder_path, auditor):
    """Audit every PHP file under *folder_path* with *auditor*.

    Returns a dict mapping file path -> list of finding strings produced by
    ``auditor.audit_php_code``.  Files with no findings are omitted.  A file
    whose reading or auditing raises maps to a one-element list describing
    the error.
    """
    report = {}
    for path in get_php_files(folder_path):
        try:
            with open(path, "r", encoding="utf-8", errors="ignore") as handle:
                source = handle.read()
            file_findings = auditor.audit_php_code(source)
        except Exception as exc:
            report[path] = [f"Error reading file: {exc}"]
            continue
        if file_findings:
            report[path] = file_findings
    return report

if __name__ == "__main__":
    # Directory to scan; "./" is the current working directory.
    folder = "./"
    auditor = PHPASTAuditor()

    results = audit_folder(folder, auditor)

    for php_path, php_findings in results.items():
        print(f"\n[+] File: {php_path}")
        for item in php_findings:
            print(f" - {item}")

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
from phply import phplex
from phply.phpparse import make_parser
from phply import phpast as php


class PHPASTAuditor:
    """Scope-insensitive taint auditor for PHP source, built on phply ASTs.

    Taint propagation:
    - A variable becomes tainted when it is assigned, by ``=`` or by a
      compound assignment such as ``.=``, from a superglobal listed in
      ``user_input_sources`` or from an already-tainted variable.
    - Calls to ``safe_functions`` are treated as clean.

    Reporting: calls (functions or methods) whose name is in
    ``danger_functions`` are reported when any argument carries taint.

    Includes: string-literal ``include``/``require`` targets are parsed, and
    their assignments feed the same taint map.

    The taint map is one flat dict keyed by variable name, so function
    scopes are not distinguished.
    """

    # Superglobals whose values are treated as attacker-controlled.
    user_input_sources = ('$_GET', '$_POST', '$_REQUEST', '$_COOKIE')

    def __init__(self):
        self._init_parser()
        # Call names reported when any argument is tainted. 'query' and
        # 'execute' are usually method names ($db->query(...)), which is why
        # get_func_name also resolves method calls.
        # NOTE(review): if phply parses eval(...) as a language construct
        # rather than a FunctionCall, the 'eval' entry never matches — confirm.
        self.danger_functions = [
            'mysql_query', 'mysqli_query', 'pg_query',
            'execute', 'eval', 'exec', 'shell_exec',
            'system', 'passthru', 'query'
        ]
        # Calls whose return value is considered sanitized.
        # NOTE(review): htmlspecialchars/addslashes are not dependable SQL or
        # shell sanitizers; treating them as safe may hide real injections.
        self.safe_functions = [
            'mysql_real_escape_string', 'addslashes',
            'htmlspecialchars', 'intval', 'mysqli_real_escape_string'
        ]
        # Variable name (e.g. '$id') -> True once the variable is tainted.
        self.taint_map = {}
        # Absolute paths parsed during the current audit; guards include cycles.
        self.visited_files = set()
        # Top-level AST nodes of every included file parsed in the current audit.
        self.included_asts = []

    def _init_parser(self):
        """Recreate lexer and parser so each parse starts counting at line 1."""
        self.lexer = phplex.lexer.clone()
        self.parser = make_parser()

    # ---------------- basic helpers ----------------
    def parse_php_code(self, code):
        """Parse PHP *code* and return phply's list of top-level nodes."""
        return self.parser.parse(code, lexer=self.lexer)

    def tag_ast_file(self, node, filename):
        """Store *filename* as ``__file__`` on *node* and all its descendant nodes."""
        if not isinstance(node, php.Node):
            return
        setattr(node, '__file__', filename)
        for child in self.safe_iter_node(node):
            self.tag_ast_file(child, filename)

    def safe_iter_node(self, node):
        """Yield the direct child Nodes of *node*, including those inside list attributes."""
        if not isinstance(node, php.Node):
            return
        for value in getattr(node, '__dict__', {}).values():
            if isinstance(value, php.Node):
                yield value
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, php.Node):
                        yield item

    def get_func_name(self, node):
        """Return the called name for function, method and static-method calls, else None.

        Method calls are included so entries like 'query' and 'execute' in
        ``danger_functions`` can match ``$db->query(...)``.
        """
        if isinstance(node, (php.FunctionCall, php.MethodCall, php.StaticMethodCall)):
            if hasattr(node.name, 'name'):
                return node.name.name
            if isinstance(node.name, str):
                return node.name
        return None

    def is_user_input(self, node):
        """Return True for a superglobal variable or an element access on one."""
        if isinstance(node, php.ArrayOffset) and isinstance(node.node, php.Variable):
            return node.node.name in self.user_input_sources
        if isinstance(node, php.Variable):
            return node.name in self.user_input_sources
        return False

    def mark_taint(self, var_name):
        """Record *var_name* as tainted."""
        self.taint_map[var_name] = True

    def is_tainted(self, var_name):
        """Return whether *var_name* has been marked tainted.

        Non-string names (e.g. the inner node of a variable-variable ``$$x``)
        are never tainted; they are rejected up front because they may not
        be hashable.
        """
        if not isinstance(var_name, str):
            return False
        return self.taint_map.get(var_name, False)

    # ---------------- taint analysis ----------------
    def contains_user_input_or_taint(self, node):
        """Return True if *node* may carry user input or a tainted variable.

        Results of ``safe_functions`` calls count as clean. Literal values
        (anything that is not a phply Node) count as clean.
        """
        if isinstance(node, php.Parameter):
            return self.contains_user_input_or_taint(node.node)
        if isinstance(node, php.Variable):
            return self.is_user_input(node) or self.is_tainted(node.name)
        if isinstance(node, php.ArrayOffset):
            return self.is_user_input(node) or self.contains_user_input_or_taint(node.node)
        if isinstance(node, php.BinaryOp):
            return (self.contains_user_input_or_taint(node.left)
                    or self.contains_user_input_or_taint(node.right))
        if isinstance(node, php.FunctionCall):
            if self.get_func_name(node) in self.safe_functions:
                return False
            return any(self.contains_user_input_or_taint(p) for p in getattr(node, 'params', []))
        if isinstance(node, php.Node):
            return any(self.contains_user_input_or_taint(child)
                       for child in self.safe_iter_node(node))
        return False

    def _assignment_parts(self, node):
        """Return ``(target_name, value_expr)`` for ``$v = e`` / ``$v op= e``, else ``(None, None)``."""
        if isinstance(node, php.Assignment):
            target, source = node.node, node.expr
        elif isinstance(node, php.AssignOp):
            # Compound assignments such as $sql .= $_GET['id'] also propagate taint.
            target, source = node.left, node.right
        else:
            return None, None
        if isinstance(target, php.Variable) and isinstance(target.name, str):
            return target.name, source
        return None, None

    def collect_taints_iterative(self, ast_nodes):
        """Propagate taint over *ast_nodes* until no new variable gets tainted."""
        changed = True
        while changed:
            changed = False
            stack = list(ast_nodes)
            while stack:
                node = stack.pop()
                target, source = self._assignment_parts(node)
                if (target is not None and not self.is_tainted(target)
                        and self.contains_user_input_or_taint(source)):
                    self.mark_taint(target)
                    changed = True
                stack.extend(self.safe_iter_node(node))

    # ---------------- dangerous call check ----------------
    def check_danger_functions(self, node):
        """Return one finding per dangerous call under *node* that has a tainted argument."""
        findings = []
        if not isinstance(node, php.Node):
            return findings
        func_name = self.get_func_name(node)
        # One finding per call site, even if several arguments are tainted.
        if func_name in self.danger_functions and any(
                self.contains_user_input_or_taint(arg) for arg in getattr(node, 'params', [])):
            src_file = getattr(node, '__file__', 'unknown')
            line = getattr(node, 'lineno', 'unknown')
            findings.append(
                f"[TAINT] Dangerous function '{func_name}' in {src_file} at line {line}"
            )
        for child in self.safe_iter_node(node):
            findings.extend(self.check_danger_functions(child))
        return findings

    # ---------------- include/require ----------------
    @staticmethod
    def _literal_include_path(expr):
        """Return the target of ``include 'x.php'`` when it is a string literal, else None."""
        # phply appears to represent string literals as plain Python str (no
        # phpast.Scalar class in the releases we know of) — verify against the
        # installed version; a Scalar-like node is still accepted if present.
        if isinstance(expr, str):
            return expr
        scalar_cls = getattr(php, 'Scalar', None)
        if scalar_cls is not None and isinstance(expr, scalar_cls) and isinstance(expr.value, str):
            return expr.value
        return None

    def process_include(self, node, current_dir):
        """Parse the file named by include/require *node*.

        The included file's assignments are added to the taint map, and its
        own includes are followed recursively. Non-literal targets, missing
        files and already-visited files are skipped. Parse/read failures are
        reported with a warning and otherwise ignored.
        """
        included_file = self._literal_include_path(node.expr)
        if not included_file:
            return
        included_path = os.path.abspath(os.path.join(current_dir, included_file))
        if not os.path.isfile(included_path) or included_path in self.visited_files:
            return
        self.visited_files.add(included_path)
        try:
            self._init_parser()  # fresh parser state for every included file
            with open(included_path, "r", encoding="utf-8", errors="ignore") as f:
                code = f.read()
            ast_nodes = self.parse_php_code(code) or []
            for n in ast_nodes:
                self.tag_ast_file(n, included_path)
            self.included_asts.extend(ast_nodes)
            self.collect_taints_iterative(ast_nodes)
            for n in ast_nodes:
                for child in self.find_include_nodes(n):
                    self.process_include(child, os.path.dirname(included_path))
        except Exception as e:
            print(f"[WARN] Failed to parse include file {included_path}: {e}")

    def find_include_nodes(self, node):
        """Return every Include/Require node in the subtree rooted at *node*."""
        nodes = []
        if isinstance(node, (php.Include, php.Require)):
            nodes.append(node)
        for child in self.safe_iter_node(node):
            nodes.extend(self.find_include_nodes(child))
        return nodes

    # ---------------- main entry ----------------
    def audit_php_file(self, filepath):
        """Audit one PHP file.

        Returns a ``(findings, tainted_variable_names)`` tuple. On any error
        it returns ``(["Error parsing file: ..."], [])`` instead of raising.
        """
        self._init_parser()  # reset parser state for every scanned file
        self.taint_map.clear()
        self.visited_files.clear()
        self.included_asts = []

        abs_path = os.path.abspath(filepath)
        self.visited_files.add(abs_path)
        try:
            with open(filepath, "r", encoding="utf-8", errors="ignore") as f:
                code = f.read()
            ast_nodes = self.parse_php_code(code) or []
            for n in ast_nodes:
                self.tag_ast_file(n, abs_path)

            # Parse included files (this also collects their local taints).
            for n in ast_nodes:
                for inc in self.find_include_nodes(n):
                    self.process_include(inc, os.path.dirname(abs_path))

            # Solve taint over the main file and all included files together,
            # so assignments in either can feed the other.
            self.collect_taints_iterative(list(ast_nodes) + self.included_asts)

            # Only the main file's calls are reported.
            findings = []
            for n in ast_nodes:
                findings.extend(self.check_danger_functions(n))

            return findings, list(self.taint_map.keys())
        except Exception as e:
            return [f"Error parsing file: {e}"], []






main.py
from php_auditor import PHPASTAuditor
import os
def scan_folder(folder):
    """Audit every ``.php`` file below *folder* with one shared auditor.

    Returns a dict mapping each file path to
    ``{"taints": [...], "findings": [...]}`` built from the tuple returned by
    ``PHPASTAuditor.audit_php_file``.
    """
    auditor = PHPASTAuditor()
    report = {}
    for dirpath, _subdirs, filenames in os.walk(folder):
        php_paths = (os.path.join(dirpath, fname) for fname in filenames if fname.endswith(".php"))
        for path in php_paths:
            findings, taints = auditor.audit_php_file(path)
            report[path] = {"taints": taints, "findings": findings}
    return report
# Directory to scan; "./" is the current working directory.
folder = "./"
results = scan_folder(folder)
for php_path, report in results.items():
    print(f"\n[+] File: {php_path}")
    print("Collected taint variables:", report["taints"])
    for line in report["findings"]:
        print(line)