2023-08-21 16:02:05 +03:00
|
|
|
''' Utility: Graph implementation. '''
|
2024-04-26 00:38:22 +03:00
|
|
|
import copy
|
|
|
|
from typing import Generic, Iterable, Optional, TypeVar
|
2023-08-21 16:02:05 +03:00
|
|
|
|
2024-04-26 00:38:22 +03:00
|
|
|
ItemType = TypeVar("ItemType")


class Graph(Generic[ItemType]):
    ''' Directed graph.

    Adjacency is stored in both directions so either can be walked cheaply:

    * ``outputs[node]`` - nodes that ``node`` has an edge to (successors);
    * ``inputs[node]`` - nodes that have an edge to ``node`` (predecessors).

    Every node is a key of both dictionaries. Node identifiers must be hashable,
    since they are used as dictionary keys and set members.
    '''

    def __init__(self, graph: Optional[dict[ItemType, list[ItemType]]] = None):
        ''' Create a graph, optionally from an adjacency mapping.

        :param graph: mapping ``node -> list of successors``. It is adopted as
            ``self.outputs`` without copying, so it is shared with (and later
            mutated by) this graph. Successors that are not keys of the mapping
            are added to it as nodes without outgoing edges.
        '''
        self.outputs: dict[ItemType, list[ItemType]] = {} if graph is None else graph
        self.inputs: dict[ItemType, list[ItemType]] = {node_id: [] for node_id in self.outputs}

        # Iterate over a snapshot: missing successors are inserted into
        # ``self.outputs`` below, and a dict must not change size while iterated.
        for parent, children in list(self.outputs.items()):
            for child in children:
                if child not in self.inputs:
                    # Node mentioned only as an edge target - register it so that
                    # every node is a key of both adjacency dicts.
                    self.outputs[child] = []
                    self.inputs[child] = []
                self.inputs[child].append(parent)

    def contains(self, node_id: ItemType) -> bool:
        ''' Check if node is in graph. '''
        return node_id in self.outputs

    def has_edge(self, src: ItemType, dest: ItemType) -> bool:
        ''' Check if edge ``src -> dest`` is in graph. '''
        return self.contains(src) and dest in self.outputs[src]

    def add_node(self, node_id: ItemType):
        ''' Add node to graph; does nothing if it is already present. '''
        if not self.contains(node_id):
            self.outputs[node_id] = []
            self.inputs[node_id] = []

    def add_edge(self, src: ItemType, dest: ItemType):
        ''' Add edge ``src -> dest``, creating missing endpoints; an existing edge is not duplicated. '''
        self.add_node(src)
        self.add_node(dest)
        if dest not in self.outputs[src]:
            self.outputs[src].append(dest)
        if src not in self.inputs[dest]:
            self.inputs[dest].append(src)

    def remove_edge(self, src: ItemType, dest: ItemType):
        ''' Remove edge ``src -> dest``; does nothing if an endpoint or the edge is missing. '''
        if not self.contains(src) or not self.contains(dest):
            return
        if dest in self.outputs[src]:
            self.outputs[src].remove(dest)
        if src in self.inputs[dest]:
            self.inputs[dest].remove(src)

    def remove_node(self, target: ItemType):
        ''' Remove node and the edges attached to it; unknown nodes are ignored. '''
        if not self.contains(target):
            return
        del self.outputs[target]
        del self.inputs[target]
        # Drop the node from the remaining adjacency lists in both directions.
        for list_out in self.outputs.values():
            if target in list_out:
                list_out.remove(target)
        for list_in in self.inputs.values():
            if target in list_in:
                list_in.remove(target)

    def _expand(self, origin: Iterable[ItemType], edges: dict[ItemType, list[ItemType]]) -> list[ItemType]:
        ''' Breadth-first collection of nodes reachable from ``origin`` along ``edges``.

        Origin nodes themselves are never included, each reachable node appears
        once (in discovery order), and origin nodes unknown to the graph are ignored.
        '''
        # Materialize first: ``origin`` is consumed twice below (to seed ``seen``
        # and to start the search), which would silently exhaust a generator.
        origin = list(origin)

        # Nodes that must not be added again: origin nodes plus everything queued.
        seen: set[ItemType] = set(origin)
        result: list[ItemType] = []

        def enqueue_neighbours(node_id: ItemType) -> None:
            for neighbour in edges[node_id]:
                if neighbour not in seen:
                    seen.add(neighbour)
                    result.append(neighbour)

        for node_id in origin:
            if self.contains(node_id):
                enqueue_neighbours(node_id)

        # ``result`` doubles as the BFS queue; it grows while being scanned.
        position = 0
        while position < len(result):
            enqueue_neighbours(result[position])
            position += 1

        return result

    def expand_inputs(self, origin: Iterable[ItemType]) -> list[ItemType]:
        ''' Expand origin nodes backward through graph edges.

        Returns every node from which some origin node can be reached (its
        ancestors), excluding the origin nodes, in breadth-first order.
        '''
        return self._expand(origin, self.inputs)

    def expand_outputs(self, origin: Iterable[ItemType]) -> list[ItemType]:
        ''' Expand origin nodes forward through graph edges.

        Returns every node reachable from some origin node (its descendants),
        excluding the origin nodes, in breadth-first order.
        '''
        return self._expand(origin, self.outputs)

    def transitive_closure(self) -> dict[ItemType, list[ItemType]]:
        ''' Generate transitive closure - list of reachable nodes for each node.

        Exact for acyclic graphs; when the graph has cycles some reachable
        nodes may be missing from the lists.
        '''
        # Per-node list copies only. ``copy.deepcopy`` would also copy the node
        # objects used as keys, so lookups by the original nodes could fail.
        result: dict[ItemType, list[ItemType]] = {
            node_id: list(children) for node_id, children in self.outputs.items()
        }

        # Sinks first: a node's list is complete before it is pushed into its parents.
        for node_id in reversed(self.topological_order()):
            for parent in self.inputs[node_id]:
                known = set(result[parent])
                result[parent] = result[parent] + [n for n in result[node_id] if n not in known]

        return result

    def topological_order(self) -> list[ItemType]:
        ''' Return nodes in SOME topological order.

        Iterative depth-first search: a node enters the post-order once the
        successors pushed after it are done; the reversed post-order is a
        topological order for acyclic graphs. Each node appears exactly once.
        '''
        post_order: list[ItemType] = []
        finished: set[ItemType] = set()  # nodes already in post_order (O(1) lookup)
        marked: set[ItemType] = set()    # nodes whose successors were scheduled

        for root in self.outputs:
            if root in marked:
                continue
            stack: list[ItemType] = [root]
            while stack:
                node = stack[-1]
                if node in marked:
                    # Seen again on top: its subtree is done. Stale duplicate
                    # entries of already finished nodes are simply dropped.
                    stack.pop()
                    if node not in finished:
                        finished.add(node)
                        post_order.append(node)
                else:
                    marked.add(node)
                    stack.extend(child for child in self.outputs[node] if child not in marked)

        post_order.reverse()
        return post_order

    def sort_stable(self, target: list[ItemType]) -> list[ItemType]:
        ''' Returns target stable sorted in topological order based on minimal modifications.

        The list is scanned from the end. A node reachable from some node placed
        after it is moved to directly after the last such node; if that node and
        the moved node reach each other (a cycle), the moved node goes to the
        front instead. Nodes unknown to the graph are unconstrained and never
        moved. Lists with fewer than two elements are returned as-is (same object).
        '''
        if len(target) <= 1:
            return target

        reachable = self.transitive_closure()

        def reach(node_id: ItemType) -> list[ItemType]:
            # Unknown nodes reach nothing.
            return reachable.get(node_id, [])

        test_set: set[ItemType] = set()  # everything reachable from processed nodes
        result: list[ItemType] = []      # output built in reverse order

        for node_id in reversed(target):
            need_move = node_id in test_set
            test_set.update(reach(node_id))

            if not need_move:
                result.append(node_id)
                continue

            # First hit in the reversed list = last node in the output reaching it.
            for index, parent in enumerate(result):
                if node_id in reach(parent):
                    if parent in reach(node_id):
                        result.append(node_id)
                    else:
                        result.insert(index, node_id)
                    break

        result.reverse()
        return result
|