diff --git a/rsconcept/backend/apps/rsform/graph.py b/rsconcept/backend/apps/rsform/graph.py new file mode 100644 index 00000000..95125c77 --- /dev/null +++ b/rsconcept/backend/apps/rsform/graph.py @@ -0,0 +1,73 @@ +''' Utility: Graph implementation. ''' +from typing import Dict, Iterable, Optional, cast + + +class Graph: + ''' Directed graph. ''' + def __init__(self, graph: Optional[Dict[str, list[str]]]=None): + if graph is None: + self._graph = cast(Dict[str, list[str]], dict()) + else: + self._graph = graph + + def contains(self, node_id: str) -> bool: + ''' Check if node is in graph. ''' + return node_id in self._graph + + def has_edge(self, id_from: str, id_to: str) -> bool: + ''' Check if edge is in graph. ''' + return self.contains(id_from) and id_to in self._graph[id_from] + + def add_node(self, node_id: str): + ''' Add node to graph. ''' + if not self.contains(node_id): + self._graph[node_id] = [] + + def add_edge(self, id_from: str, id_to: str): + ''' Add edge to graph. ''' + self.add_node(id_from) + self.add_node(id_to) + if id_to not in self._graph[id_from]: + self._graph[id_from].append(id_to) + + def expand_outputs(self, origin: Iterable[str]) -> list[str]: + ''' Expand origin nodes forward through graph edges. ''' + result: list[str] = [] + marked: set[str] = set(origin) + for node_id in origin: + if self.contains(node_id): + for child_id in self._graph[node_id]: + if child_id not in marked and child_id not in result: + result.append(child_id) + position: int = 0 + while position < len(result): + node_id = result[position] + position += 1 + if (node_id not in marked): + marked.add(node_id) + for child_id in self._graph[node_id]: + if child_id not in marked and child_id not in result: + result.append(child_id) + return result + + def topological_order(self) -> list[str]: + ''' Return nodes in topological order. ''' + result: list[str] = [] + marked: set[str] = set() + for node_id in self._graph.keys(): + if node_id not in marked: + to_visit: list[str] = [node_id] + while len(to_visit) > 0: + node = to_visit[-1] + if node in marked: + if node not in result: + result.append(node) + to_visit.remove(node) + else: + marked.add(node) + if len(self._graph[node]) > 0: + for child_id in self._graph[node]: + if child_id not in marked: + to_visit.append(child_id) + result.reverse() + return result diff --git a/rsconcept/backend/apps/rsform/tests/__init__.py b/rsconcept/backend/apps/rsform/tests/__init__.py index 9f51ce17..48cc1f6b 100644 --- a/rsconcept/backend/apps/rsform/tests/__init__.py +++ b/rsconcept/backend/apps/rsform/tests/__init__.py @@ -3,3 +3,4 @@ from .t_imports import * from .t_views import * from .t_models import * from .t_serializers import * +from .t_graph import * diff --git a/rsconcept/backend/apps/rsform/tests/t_graph.py b/rsconcept/backend/apps/rsform/tests/t_graph.py new file mode 100644 index 00000000..4cea573f --- /dev/null +++ b/rsconcept/backend/apps/rsform/tests/t_graph.py @@ -0,0 +1,64 @@ +''' Unit tests: graph. ''' +import unittest + +from apps.rsform.graph import Graph + + +class TestGraph(unittest.TestCase): + ''' Test class for graph. ''' + def test_construction(self): + ''' Test graph construction methods. ''' + graph = Graph() + self.assertFalse(graph.contains('X1')) + + graph.add_node('X1') + self.assertTrue(graph.contains('X1')) + + graph.add_edge('X2', 'X3') + self.assertTrue(graph.contains('X2')) + self.assertTrue(graph.contains('X3')) + self.assertTrue(graph.has_edge('X2', 'X3')) + self.assertFalse(graph.has_edge('X3', 'X2')) + + graph = Graph({'X1': ['X3', 'X4'], 'X2': ['X1'], 'X3': [], 'X4': [], 'X5': []}) + self.assertTrue(graph.contains('X1')) + self.assertTrue(graph.contains('X5')) + self.assertTrue(graph.has_edge('X1', 'X3')) + self.assertTrue(graph.has_edge('X2', 'X1')) + + def test_expand_outputs(self): + ''' Test Method: Graph.expand_outputs. ''' + graph = Graph({ + 'X1': ['X2'], + 'X2': ['X3', 'X5'], + 'X3': [], + 'X5': ['X6'], + 'X6': ['X1'], + 'X7': [] + }) + self.assertEqual(graph.expand_outputs([]), []) + self.assertEqual(graph.expand_outputs(['X3']), []) + self.assertEqual(graph.expand_outputs(['X7']), []) + self.assertEqual(graph.expand_outputs(['X2', 'X5']), ['X3', 'X6', 'X1']) + + def test_topological_order(self): + ''' Test Method: Graph.topological_order. ''' + self.assertEqual(Graph().topological_order(), []) + graph = Graph({ + 'X1': [], + 'X2': ['X1'], + 'X3': [], + 'X4': ['X3'], + 'X5': ['X6'], + 'X6': ['X1', 'X2'] + }) + self.assertEqual(graph.topological_order(), ['X5', 'X6', 'X4', 'X3', 'X2', 'X1']) + + graph = Graph({ + 'X1': ['X1'], + 'X2': ['X4'], + 'X3': ['X2'], + 'X4': [], + 'X5': ['X2'], + }) + self.assertEqual(graph.topological_order(), ['X5', 'X3', 'X2', 'X4', 'X1']) diff --git a/rsconcept/frontend/src/utils/Graph.test.ts b/rsconcept/frontend/src/utils/Graph.test.ts index 063b42f4..d577db4b 100644 --- a/rsconcept/frontend/src/utils/Graph.test.ts +++ b/rsconcept/frontend/src/utils/Graph.test.ts @@ -73,7 +73,7 @@ describe('Testing Graph editing', () => { describe('Testing Graph sort', () => { test('topological order', () => { const graph = new Graph([[9, 1], [9, 2], [2, 1], [4, 3], [5, 9]]); - expect(graph.tolopogicalOrder()).toStrictEqual([5, 4, 3, 9, 1, 2]); + expect(graph.tolopogicalOrder()).toStrictEqual([5, 4, 3, 9, 2, 1]); }); }); diff --git a/rsconcept/frontend/src/utils/Graph.ts b/rsconcept/frontend/src/utils/Graph.ts index 4e873e6d..206758a7 100644 --- a/rsconcept/frontend/src/utils/Graph.ts +++ b/rsconcept/frontend/src/utils/Graph.ts @@ -194,20 +194,19 @@ export class Graph { tolopogicalOrder(): number[] { const result: number[] = []; const marked = new Map(); + const toVisit: number[] = []; this.nodes.forEach(node => { if (marked.get(node.id)) { return; } - const toVisit: number[] = [node.id]; - let index = 0; + toVisit.push(node.id) while (toVisit.length > 0) { - const item = toVisit[index]; + const item = toVisit[toVisit.length - 1]; if (marked.get(item)) { - if (!result.find(id => id ===item)) { + if (!result.find(id => id === item)) { result.push(item); } - toVisit.splice(index, 1); - index -= 1; + toVisit.pop(); } else { marked.set(item, true); const itemNode = this.nodes.get(item); @@ -218,12 +217,8 @@ export class Graph { } }); } - if (index + 1 < toVisit.length) { - index += 1; - } } } - marked }); return result.reverse(); }