From 8a965c8e6e3433d536341e314b7367d6aac816dd Mon Sep 17 00:00:00 2001 From: Akumatic Date: Fri, 6 Dec 2019 17:32:20 +0100 Subject: [PATCH] Refactor 2019 day 06: Added cache --- 2019/06/code.py | 82 ++++++++++++++++++++++--------------------------- 1 file changed, 36 insertions(+), 46 deletions(-) diff --git a/2019/06/code.py b/2019/06/code.py index eaffaec..69bcc90 100644 --- a/2019/06/code.py +++ b/2019/06/code.py @@ -2,60 +2,50 @@ def readFile(): with open(f"{__file__.rstrip('code.py')}input.txt", "r") as f: - return [data[:-1].split(")") for data in f.readlines()] + lines = [data[:-1].split(")") for data in f.readlines()] + return {line[1] : line[0] for line in lines} -def parse(lines): - data = {} - for line in lines: - data[line[1]] = line[0] - return data +def countOrbits(data, cache, node): + if node in cache: return cache[node] + cache[node] = 0 if node not in data else 1 + countOrbits(data,cache,data[node]) + return cache[node] -def countOrbits(data, node): - if node not in data: - return 0 - return 1 + countOrbits(data, data[node]) +def getIntersection(data, cache, node1, node2): + if cache[node1] > cache[node2]: node1, node2 = node2, node1 + parents = set() + # get elements of shorter path + while data[node1] in data: + parents.add(data[node1]) + node1 = data[node1] + parents.add(data[node1]) + # look for first node present in both paths + while node2 not in parents: + node2 = data[node2] + return node2 -def getIntersection(data, node1, node2): - node = node1 - parents = [] +def part1(vals : dict, cache): + return sum([countOrbits(vals,cache,val) for val in vals]) - while data[node] in data: - parents.append(data[node]) - node = data[node] - parents.append(data[node]) - - node = node2 - while node not in parents: - node = data[node] - - return node - -def getJumps(data, start, goal, intersection): - iDist = countOrbits(data, intersection) - return countOrbits(data, start) + countOrbits(data, goal) - 2*iDist - 2 - -def part1(vals : list): - return sum([countOrbits(vals, val) for val in vals]) - -def part2(vals : list): - intersection = getIntersection(vals, "YOU", "SAN") - return getJumps(vals, "YOU", "SAN", intersection) +def part2(vals : dict, cache): + intersection = getIntersection(vals,cache,"YOU","SAN") + return cache["YOU"] + cache["SAN"] - 2*cache[intersection] - 2 def test(): - lines = [["COM","B"],["B","C"],["C","D"],["D","E"],["E","F"], - ["B","G"],["G","H"],["D","I"],["E","J"],["J","K"],["K","L"]] - vals = parse(lines) - assert countOrbits(vals, "D") == 3 - assert countOrbits(vals, "L") == 7 - assert countOrbits(vals, "COM") == 0 - assert sum([countOrbits(vals, val) for val in vals]) == 42 + vals, cache = {"B":"COM","C":"B","D":"C","E":"D","F":"E","G":"B", + "H":"G","I":"D","J":"E","K":"J","L":"K"}, {} + assert countOrbits(vals,cache,"D") == 3 + assert countOrbits(vals,cache,"L") == 7 + assert not countOrbits(vals,cache,"COM") + assert sum([countOrbits(vals,cache,val) for val in vals]) == 42 vals["YOU"] = "K" vals["SAN"] = "I" - assert getIntersection(vals, "YOU", "SAN") == "D" - assert getJumps(vals, "YOU", "SAN", "D") == 4 + countOrbits(vals,cache,"YOU"), countOrbits(vals,cache,"SAN") + assert getIntersection(vals,cache,"YOU","SAN") == "D" + assert countOrbits(vals,cache,"YOU") + countOrbits(vals,cache,"SAN") - \ + 2*cache["D"] == 6 if __name__ == "__main__": test() - vals = parse(readFile()) - print(f"Part 1: {part1(vals)}") - print(f"Part 2: {part2(vals)}") \ No newline at end of file + vals, cache = readFile(), {} + print(f"Part 1: {part1(vals, cache)}") + print(f"Part 2: {part2(vals, cache)}") \ No newline at end of file