mirror of
https://github.com/IRBorisov/ConceptPortal.git
synced 2025-06-26 21:10:38 +03:00
76 lines
2.7 KiB
Python
76 lines
2.7 KiB
Python
''' Utility: directed graph implementation with reachability expansion and topological ordering. '''
|
|
from typing import Dict, Iterable, Optional, cast
|
|
|
|
|
|
class Graph:
    ''' Directed graph stored as an adjacency mapping: node id -> list of child node ids. '''

    def __init__(self, graph: Optional[Dict[str, list[str]]] = None):
        ''' Create a graph, optionally wrapping an existing adjacency mapping.

        The given mapping is used as-is (not copied), so later add_node / add_edge
        calls mutate the caller's dict. Child ids that are not keys of the mapping
        are treated by the traversal methods as nodes without outgoing edges.
        '''
        if graph is None:
            self._graph = cast(Dict[str, list[str]], {})
        else:
            self._graph = graph

    def contains(self, node_id: str) -> bool:
        ''' Check if node is in graph (i.e. is a key of the adjacency mapping). '''
        return node_id in self._graph

    def has_edge(self, id_from: str, id_to: str) -> bool:
        ''' Check if edge id_from -> id_to is in graph. '''
        return self.contains(id_from) and id_to in self._graph[id_from]

    def add_node(self, node_id: str):
        ''' Add node to graph; no-op if it is already present. '''
        if not self.contains(node_id):
            self._graph[node_id] = []

    def add_edge(self, id_from: str, id_to: str):
        ''' Add edge id_from -> id_to, creating missing endpoints; duplicate edges are ignored. '''
        self.add_node(id_from)
        self.add_node(id_to)
        if id_to not in self._graph[id_from]:
            self._graph[id_from].append(id_to)

    def expand_outputs(self, origin: Iterable[str]) -> list[str]:
        ''' Expand origin nodes forward through graph edges.

        Returns every node reachable from any origin node via one or more edges,
        in breadth-first discovery order. Origin nodes themselves are never part
        of the result, and unknown origin ids are ignored. `origin` is iterated
        exactly once, so generators and other one-shot iterators are accepted.
        '''
        # Materialize first: the input is needed both to seed `seen` and as the
        # initial frontier, and a one-shot iterator could only be read once.
        to_expand: list[str] = list(origin)
        seen: set[str] = set(to_expand)
        result: list[str] = []
        position = 0
        # `to_expand` doubles as the BFS queue; newly discovered nodes are appended to it.
        while position < len(to_expand):
            node_id = to_expand[position]
            position += 1
            for child_id in self._graph.get(node_id, ()):
                if child_id not in seen:
                    seen.add(child_id)
                    result.append(child_id)
                    to_expand.append(child_id)
        return result

    def topological_order(self) -> list[str]:
        ''' Return nodes in topological order (every node precedes its children).

        Implemented as an iterative depth-first search whose post-order is then
        reversed; roots are taken in the mapping's key order. Cycles are not
        detected: each node still appears exactly once, but for a cyclic graph
        no order can respect every edge.
        '''
        result: list[str] = []
        marked: set[str] = set()    # nodes whose children have been scheduled
        finished: set[str] = set()  # nodes already emitted to `result`
        for root in self._graph:
            if root in marked:
                continue
            stack: list[str] = [root]
            while stack:
                node = stack[-1]
                if node in marked:
                    # Either all children are done (post-order point) or this is
                    # a stale duplicate entry of an already finished node.
                    stack.pop()
                    if node not in finished:
                        finished.add(node)
                        result.append(node)
                else:
                    marked.add(node)
                    stack.extend(
                        child_id for child_id in self._graph.get(node, ())
                        if child_id not in marked
                    )
        result.reverse()
        return result
|