diff --git a/erpnext/hr/page/organizational_chart/organizational_chart.py b/erpnext/hr/page/organizational_chart/organizational_chart.py index 3674912dc06..717b56138e5 100644 --- a/erpnext/hr/page/organizational_chart/organizational_chart.py +++ b/erpnext/hr/page/organizational_chart/organizational_chart.py @@ -1,4 +1,5 @@ import frappe +from frappe.query_builder.functions import Count @frappe.whitelist() @@ -15,31 +16,34 @@ def get_children(parent=None, company=None, exclude_node=None): if exclude_node: filters.append(["name", "!=", exclude_node]) - employees = frappe.get_list( + employees = frappe.get_all( "Employee", - fields=["employee_name as name", "name as id", "reports_to", "image", "designation as title"], + fields=[ + "employee_name as name", + "name as id", + "lft", + "rgt", + "reports_to", + "image", + "designation as title", + ], filters=filters, order_by="name", ) for employee in employees: - is_expandable = frappe.db.count("Employee", filters={"reports_to": employee.get("id")}) - employee.connections = get_connections(employee.id) - employee.expandable = 1 if is_expandable else 0 + employee.connections = get_connections(employee.id, employee.lft, employee.rgt) + employee.expandable = bool(employee.connections) return employees -def get_connections(employee): - num_connections = 0 +def get_connections(employee: str, lft: int, rgt: int) -> int: + Employee = frappe.qb.DocType("Employee") + query = ( + frappe.qb.from_(Employee) + .select(Count(Employee.name)) + .where((Employee.lft > lft) & (Employee.rgt < rgt)) + ).run() - nodes_to_expand = frappe.get_list("Employee", filters=[["reports_to", "=", employee]]) - num_connections += len(nodes_to_expand) - - while nodes_to_expand: - parent = nodes_to_expand.pop(0) - descendants = frappe.get_list("Employee", filters=[["reports_to", "=", parent.name]]) - num_connections += len(descendants) - nodes_to_expand.extend(descendants) - - return num_connections + return query[0][0] diff --git a/erpnext/hr/page/organizational_chart/test_organizational_chart.py b/erpnext/hr/page/organizational_chart/test_organizational_chart.py new file mode 100644 index 00000000000..1c3c24845c5 --- /dev/null +++ b/erpnext/hr/page/organizational_chart/test_organizational_chart.py @@ -0,0 +1,52 @@ +# Copyright (c) 2022, Frappe Technologies Pvt. Ltd. and Contributors +# License: GNU General Public License v3. See license.txt + +import unittest + +import frappe +from frappe.tests.utils import FrappeTestCase + +from erpnext.hr.doctype.employee.test_employee import make_employee +from erpnext.hr.page.organizational_chart.organizational_chart import get_children + + +class TestOrganizationalChart(FrappeTestCase): + def setUp(self): + self.company = create_company("Test Org Chart").name + frappe.db.delete("Employee", {"company": self.company}) + + def test_get_children(self): + company = create_company("Test Org Chart").name + emp1 = make_employee("testemp1@mail.com", company=self.company) + emp2 = make_employee("testemp2@mail.com", company=self.company, reports_to=emp1) + emp3 = make_employee("testemp3@mail.com", company=self.company, reports_to=emp1) + make_employee("testemp4@mail.com", company=self.company, reports_to=emp2) + + # root node + children = get_children(company=self.company) + self.assertEqual(len(children), 1) + self.assertEqual(children[0].id, emp1) + self.assertEqual(children[0].connections, 3) + + # root's children + children = get_children(parent=emp1, company=self.company) + self.assertEqual(len(children), 2) + self.assertEqual(children[0].id, emp2) + self.assertEqual(children[0].connections, 1) + self.assertEqual(children[1].id, emp3) + self.assertEqual(children[1].connections, 0) + + +def create_company(name): + if frappe.db.exists("Company", name): + return frappe.get_doc("Company", name) + + company = frappe.new_doc("Company") + company.update( + { + "company_name": name, + "default_currency": "USD", + "country": "United States", + } + ) + return company.insert()