Skip to content

Backtracking Search Algorithm in Python

This tutorial includes an implementation of a backtracking search algorithm in Python. Backtracking search is a recursive algorithm that is used to find solutions to constraint satisfaction problems (CSP). In this tutorial I am going to solve a sudoku and a scheduling problem. Both problems have constraints, but the scheduling problem also has a time variable that can be minimized.

A backtracking search algorithm tries to assign a value to a variable on each recursion and backtracks (goes back and tries another value) if it has no more legal values to assign. A pure backtracking algorithm can be rather slow, but we can improve its performance by guiding it in the right direction.

We can use arc consistency to speed up backtracking: we only include legal values in the domain of each variable, so there are fewer values to choose from. We can also use the most-constrained-variable (minimum-remaining-values) heuristic to select the variable with the fewest legal values first.

Sudoku

This code can be used to solve sudoku puzzles of different sizes. I have included two backtracking algorithms in this code: backtracking_search_1 and an optimized version called backtracking_search_2. Simple sudoku puzzles can be solved in a reasonable time with the first algorithm, while harder puzzles must be solved with the second version.

# Import libraries
import copy

# This class represents a sudoku puzzle
class Sodoku:
    """A sudoku grid together with the candidate numbers for each empty cell.

    The grid is ``size`` x ``size`` and is divided into sub-matrices that are
    ``sub_row_size`` rows high and ``sub_column_size`` columns wide. Empty
    cells hold ``0``; legal numbers are ``1..size``.

    Attributes:
        state: The grid as a list of rows; the solvers fill it in place.
        size: Number of rows (and columns) in the grid.
        sub_column_size: Number of columns in each sub-matrix.
        sub_row_size: Number of rows in each sub-matrix.
        domains: Maps ``(row, column)`` of each empty cell to the numbers that
            are currently legal there. An empty cell with no legal number is
            left out of the dict, so a missing empty cell marks a dead end.
    """

    # Create a new sudoku
    def __init__(self, state: list, size: int, sub_column_size: int, sub_row_size: int):
        self.state = state
        self.size = size
        self.sub_column_size = sub_column_size
        self.sub_row_size = sub_row_size
        self.domains = {}

        # Restrict every empty cell to the numbers consistent with the clues
        self.update_domains()

    def update_domains(self):
        """Recompute ``self.domains`` from the current ``self.state``."""
        self.domains = {}
        for y in range(self.size):
            for x in range(self.size):
                if self.state[y][x] != 0:
                    continue
                numbers = [n for n in range(1, self.size + 1) if self.is_consistent(n, y, x)]
                # Cells without candidates are deliberately left out (dead ends)
                if numbers:
                    self.domains[(y, x)] = numbers

    def is_consistent(self, number: int, row: int, column: int) -> bool:
        """Return True if ``number`` is not yet used in the row, column or sub-matrix of the cell."""
        if any(self.state[row][x] == number for x in range(self.size)):
            return False
        if any(self.state[y][column] == number for y in range(self.size)):
            return False

        # Top-left corner of the sub-matrix that contains the cell
        row_start = (row // self.sub_row_size) * self.sub_row_size
        col_start = (column // self.sub_column_size) * self.sub_column_size
        return not any(
            self.state[y][x] == number
            for y in range(row_start, row_start + self.sub_row_size)
            for x in range(col_start, col_start + self.sub_column_size)
        )

    def get_first_empty_cell(self) -> tuple:
        """Return ``(row, column)`` of the first empty cell in row-major order, or ``(None, None)``."""
        for y in range(self.size):
            for x in range(self.size):
                if self.state[y][x] == 0:
                    return (y, x)
        return (None, None)

    def get_most_constrained_cell(self) -> tuple:
        """Return the cell with the fewest candidates, or ``(None, None)`` if ``domains`` is empty.

        Ties go to the cell that comes first in ``self.domains``.
        """
        if not self.domains:
            return (None, None)
        # min() returns the first minimal key, matching a stable sort's first element
        return min(self.domains, key=lambda cell: len(self.domains[cell]))

    def solved(self) -> bool:
        """Return True if no cell of the grid is empty."""
        return all(self.state[y][x] != 0 for y in range(self.size) for x in range(self.size))

    def _has_dead_end(self) -> bool:
        """Return True if some empty cell has no legal number left in ``self.domains``."""
        return any(
            self.state[y][x] == 0 and (y, x) not in self.domains
            for y in range(self.size)
            for x in range(self.size)
        )

    def backtracking_search_1(self) -> bool:
        """Solve the puzzle with plain backtracking, filling cells in row-major order.

        Returns:
            True if solved (``self.state`` holds the solution); False otherwise,
            in which case ``self.state`` is restored to what it was on entry.
        """
        y, x = self.get_first_empty_cell()
        if y is None or x is None:
            return True

        for number in range(1, self.size + 1):
            if self.is_consistent(number, y, x):
                self.state[y][x] = number
                if self.backtracking_search_1():
                    return True
                # Undo the assignment before trying the next number
                self.state[y][x] = 0

        return False

    def backtracking_search_2(self) -> bool:
        """Solve the puzzle with backtracking, MRV ordering and forward checking.

        ``self.domains`` is recomputed after every tentative assignment, so the
        most-constrained-cell choice uses up-to-date candidate counts, and an
        empty cell left without candidates triggers an immediate backtrack.
        On entry and on a False return, ``self.domains`` matches ``self.state``.

        Returns:
            True if solved (``self.state`` holds the solution); False otherwise,
            in which case ``self.state`` is restored to what it was on entry.
        """
        if self.solved():
            return True

        # Forward checking: an empty cell with no candidates can never be filled
        if self._has_dead_end():
            return False

        y, x = self.get_most_constrained_cell()
        if y is None or x is None:
            return False

        # Copy the candidates: self.domains is rebuilt while we iterate
        for number in list(self.domains[(y, x)]):
            # Guard against domains that are stale because state changed externally
            if not self.is_consistent(number, y, x):
                continue

            self.state[y][x] = number
            self.update_domains()
            if self.backtracking_search_2():
                return True

            # Undo the assignment and restore domains for the next candidate
            self.state[y][x] = 0
            self.update_domains()

        return False

    def print_state(self):
        """Print the grid, separating sub-matrices with ``|`` columns and ``-``/``+`` rules."""
        for y in range(self.size):
            print('| ', end='')
            # Horizontal rule before each new band of sub-matrices
            if y != 0 and y % self.sub_row_size == 0:
                for j in range(self.size):
                    print(' - ', end='')
                    if (j + 1) < self.size and (j + 1) % self.sub_column_size == 0:
                        print(' + ', end='')
                print(' |')
                print('| ', end='')
            for x in range(self.size):
                if x != 0 and x % self.sub_column_size == 0:
                    print(' | ', end='')
                # Right-align each number in a 2-character field
                digit = str(self.state[y][x]).rjust(2)
                print('{0} '.format(digit), end='')
            print(' |')
        
# Entry point: build the example puzzle, solve it and print it
def main():
    """Solve and print a 12x12 sudoku whose sub-matrices are 3 rows x 4 columns."""

    # Alternative 9x9 puzzle with 3x3 sub-matrices:
    # data = '4173698.5.3..........7......2.....6.....8.4......1.......6.3.7.5..2.....1.4......'
    # numbers = [int(c) for c in data.strip().replace('.', '0')]
    # size, sub_column_size, sub_row_size = 9, 3, 3

    # Larger puzzle, one grid row per line (0 = empty cell)
    numbers = [
        7, 0, 5, 0, 4, 0, 0, 1, 0, 0, 3, 6,
        9, 6, 0, 0, 7, 0, 0, 0, 0, 1, 4, 0,
        0, 2, 0, 0, 0, 0, 3, 6, 0, 0, 0, 8,
        0, 0, 0, 10, 8, 0, 0, 9, 3, 0, 0, 0,
        11, 0, 12, 1, 0, 0, 0, 0, 10, 0, 5, 9,
        0, 0, 6, 0, 0, 3, 12, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 7, 4, 0, 0, 9, 0, 0,
        2, 12, 0, 7, 0, 0, 0, 0, 4, 10, 0, 5,
        0, 0, 0, 11, 5, 0, 0, 2, 7, 0, 0, 0,
        1, 0, 0, 0, 3, 6, 0, 0, 0, 0, 8, 0,
        0, 11, 3, 0, 0, 0, 0, 5, 0, 0, 9, 7,
        10, 5, 0, 0, 2, 0, 0, 7, 0, 3, 0, 1,
    ]
    size = 12            # columns and rows in the grid
    sub_column_size = 4  # columns in each sub-matrix
    sub_row_size = 3     # rows in each sub-matrix

    # Cut the flat list into rows of `size` numbers; an incomplete trailing row is dropped
    row_count = len(numbers) // size
    initial_state = [numbers[r * size:(r + 1) * size] for r in range(row_count)]

    puzzle = Sodoku(initial_state, size, sub_column_size, sub_row_size)

    print('Puzzle input:')
    puzzle.print_state()

    # Use the optimized solver
    puzzle.backtracking_search_2()

    print('\nPuzzle solution:')
    puzzle.print_state()
    print()


# Tell python to run main method
if __name__ == "__main__":
    main()
Puzzle input:
|  7  0  5  0  |  4  0  0  1  |  0  0  3  6  |
|  9  6  0  0  |  7  0  0  0  |  0  1  4  0  |
|  0  2  0  0  |  0  0  3  6  |  0  0  0  8  |
|  -  -  -  -  +  -  -  -  -  +  -  -  -  -  |
|  0  0  0 10  |  8  0  0  9  |  3  0  0  0  |
| 11  0 12  1  |  0  0  0  0  | 10  0  5  9  |
|  0  0  6  0  |  0  3 12  0  |  0  0  0  0  |
|  -  -  -  -  +  -  -  -  -  +  -  -  -  -  |
|  0  0  0  0  |  0  7  4  0  |  0  9  0  0  |
|  2 12  0  7  |  0  0  0  0  |  4 10  0  5  |
|  0  0  0 11  |  5  0  0  2  |  7  0  0  0  |
|  -  -  -  -  +  -  -  -  -  +  -  -  -  -  |
|  1  0  0  0  |  3  6  0  0  |  0  0  8  0  |
|  0 11  3  0  |  0  0  0  5  |  0  0  9  7  |
| 10  5  0  0  |  2  0  0  7  |  0  3  0  1  |

Puzzle solution:
|  7 10  5  8  |  4 12  2  1  |  9 11  3  6  |
|  9  6 11  3  |  7 10  5  8  | 12  1  4  2  |
| 12  2  1  4  |  9 11  3  6  |  5  7 10  8  |
|  -  -  -  -  +  -  -  -  -  +  -  -  -  -  |
|  4  7  2 10  |  8  5  1  9  |  3 12  6 11  |
| 11  3 12  1  |  6  2  7  4  | 10  8  5  9  |
|  8  9  6  5  | 11  3 12 10  |  1  2  7  4  |
|  -  -  -  -  +  -  -  -  -  +  -  -  -  -  |
|  5  1 10  6  | 12  7  4 11  |  8  9  2  3  |
|  2 12  9  7  |  1  8  6  3  |  4 10 11  5  |
|  3  8  4 11  |  5  9 10  2  |  7  6  1 12  |
|  -  -  -  -  +  -  -  -  -  +  -  -  -  -  |
|  1  4  7  2  |  3  6  9 12  | 11  5  8 10  |
|  6 11  3 12  | 10  1  8  5  |  2  4  9  7  |
| 10  5  8  9  |  2  4 11  7  |  6  3 12  1  |

Job Shop Problem

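This problem is about scheduling the tasks of several jobs, where each task must be performed on a certain machine (the job shop problem). The tasks of each job must be executed in the order given by the job description, and the output is a schedule that lists, for each machine, the tasks it runs with their start and end times. Despite its name, the backtracking_search method below never backtracks: at each step it greedily schedules the available task that would finish earliest.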

# One step of a job: the machine it needs and how long it runs
class Task:
    """A task built from a ``(machine_id, processing_time)`` pair.

    Tasks compare by ``processing_time`` so that they can be sorted.
    """

    def __init__(self, tuple: tuple):
        machine_id, processing_time = tuple
        self.machine_id = machine_id
        self.processing_time = processing_time

    def __lt__(self, other):
        """Order tasks by processing time."""
        return self.processing_time < other.processing_time

    def __repr__(self):
        return f'(Machine: {self.machine_id}, Time: {self.processing_time})'

# A scheduled task: which job/task it is and when it runs
class Assignment:
    """Placement of task ``task_id`` of job ``job_id`` in the interval ``[start_time, end_time]``."""

    def __init__(self, job_id: int, task_id: int, start_time: int, end_time: int):
        self.job_id, self.task_id = job_id, task_id
        self.start_time, self.end_time = start_time, end_time

    def __repr__(self):
        return f'(Job: {self.job_id}, Task: {self.task_id}, Start: {self.start_time}, End: {self.end_time})'


# This class represents a schedule
class Schedule:
    """Greedy scheduler for a job shop problem.

    ``jobs[i][j]`` is the ``(machine_id, processing_time)`` pair of task ``j``
    of job ``i``. The tasks of a job must run in order, and each machine runs
    its tasks one after another.

    Attributes:
        jobs: The job descriptions passed to the constructor.
        tasks: Maps ``(job_id, task_id)`` to a ``Task`` for every task that
            has not been scheduled yet.
        assignments: Maps ``machine_id`` to the ``Assignment`` objects placed
            on that machine, in the order they were scheduled.
    """

    # Create a new schedule
    def __init__(self, jobs: list):
        self.jobs = jobs
        self.tasks = {}
        for job_id, job in enumerate(self.jobs):
            for task_id, task_spec in enumerate(job):
                self.tasks[(job_id, task_id)] = Task(task_spec)
        self.assignments = {}

    def backtracking_search(self) -> bool:
        """Schedule every remaining task.

        Despite the name, this is a greedy search that never backtracks. At
        each step it considers every task whose job predecessor is already
        scheduled. It computes when that task would end if appended to its
        machine, and schedules the one that ends earliest. Ties go to the task
        that comes first in ``self.tasks``.

        Returns:
            True once all tasks are scheduled (also when there was nothing to
            schedule); False if no remaining task could be scheduled.
        """
        # Index already-scheduled tasks by (job_id, task_id) so a predecessor
        # is found in O(1) instead of scanning every machine's list
        scheduled = {}
        for machine_tasks in self.assignments.values():
            for placed in machine_tasks:
                scheduled.setdefault((placed.job_id, placed.task_id), placed)

        # Iterate instead of recursing, so large inputs cannot hit the recursion limit
        while self.tasks:
            best_key = None
            best_assignment = None

            for key, task in self.tasks.items():
                job_id, task_id = key

                # A task may start once its job predecessor has finished;
                # the first task of a job may start at time 0
                if task_id > 0:
                    predecessor = scheduled.get((job_id, task_id - 1))
                    if predecessor is None:
                        continue
                    ready_time = predecessor.end_time
                else:
                    ready_time = 0

                # ...and once the last task on its machine has finished
                machine_tasks = self.assignments.get(task.machine_id)
                if machine_tasks:
                    ready_time = max(ready_time, machine_tasks[-1].end_time)

                end_time = ready_time + task.processing_time

                # Keep the task that ends earliest (strict < keeps the first on ties)
                if best_assignment is None or end_time < best_assignment.end_time:
                    best_key = key
                    best_assignment = Assignment(job_id, task_id, ready_time, end_time)

            # No remaining task has its predecessor scheduled
            if best_assignment is None:
                return False

            machine_id = self.tasks[best_key].machine_id
            self.assignments.setdefault(machine_id, []).append(best_assignment)
            scheduled[best_key] = best_assignment
            del self.tasks[best_key]

        return True

# Entry point: schedule a small three-job example and print the result
def main():
    """Schedule the example jobs and print each machine's assignments."""

    # Each task is (machine_id, processing_time); the tasks of a job run in order
    jobs = [
        [(0, 3), (1, 2), (2, 2)],  # Job 0
        [(0, 2), (2, 1), (1, 4)],  # Job 1
        [(1, 4), (2, 3)],          # Job 2
    ]

    schedule = Schedule(jobs)
    schedule.backtracking_search()

    # One line per machine, ordered by machine id
    print('Final solution:')
    for machine_id, machine_assignments in sorted(schedule.assignments.items()):
        print(machine_id, machine_assignments)
    print()


# Tell python to run main method
if __name__ == "__main__":
    main()
Final solution:
0 [(Job: 1, Task: 0, Start: 0, End: 2), (Job: 0, Task: 0, Start: 2, End: 5)]
1 [(Job: 2, Task: 0, Start: 0, End: 4), (Job: 0, Task: 1, Start: 5, End: 7), (Job: 1, Task: 2, Start: 7, End: 11)]
2 [(Job: 1, Task: 1, Start: 2, End: 3), (Job: 2, Task: 1, Start: 4, End: 7), (Job: 0, Task: 2, Start: 7, End: 9)]
Tags:

Leave a Reply

Your email address will not be published. Required fields are marked *