Sort a linked list of 0s, 1s and 2s

Last Updated : 2 Sep, 2025

Given the head of a linked list whose nodes contain only the values 0, 1 and 2, rearrange the list so that all 0s appear first, followed by all 1s, and then all 2s at the end, while maintaining the relative order within each group.

Examples:

Input: 1 -> 2 -> 2 -> 1 -> 2 -> 0 -> 2 -> 2

Output: 0 -> 1 -> 1 -> 2 -> 2 -> 2 -> 2 -> 2
Explanation: All the 0s are segregated to the left end of the linked list, 2s to the right end of the list, and 1s in between.

Input: 2 -> 2 -> 0 -> 1 (any arrangement of one 0, one 1 and two 2s)

Output: 0 -> 1 -> 2 -> 2
Explanation: After arranging all the 0s, 1s and 2s in the required order, the output will be 0 -> 1 -> 2 -> 2.

Try it on GfG Practice

[Naive Approach] Using an Extra Array - O(n × log n) Time and O(n) Space

The idea is to first convert the linked list into an array to easily leverage sorting, as sorting an array is straightforward and efficient. After sorting the array, we traverse the linked list again and reassign the sorted values back to the nodes.

C++
#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

class Node {
public:
    int data;     // value stored in this node
    Node* next;   // successor node, or nullptr at the end of the list

    // Builds a detached node holding new_data.
    Node(int new_data) : data(new_data), next(nullptr) {}
};

// Naive approach: copy every node value into a vector, sort it, and write
// the sorted values back into the nodes from head to tail. Only the values
// change; the node links are untouched, so the same head is returned.
Node* segregate(Node* head) {
    // zero or one node: already sorted
    if (!head || !(head->next))
        return head;

    // convert linked list to array
    vector<int> arr;
    for (Node* curr = head; curr; curr = curr->next)
        arr.push_back(curr->data);

    // sort the array
    sort(arr.begin(), arr.end());

    // reassign sorted values back to the linked list; the vector holds
    // exactly one value per node, so curr never becomes null in this loop
    // (range-for also avoids the signed/unsigned index comparison)
    Node* curr = head;
    for (int val : arr) {
        curr->data = val;
        curr = curr->next;
    }

    return head;
}

int main() {
    
    Node* head = new Node(1);
    head->next = new Node(2);
    head->next->next = new Node(2);
    head->next->next->next = new Node(1);
    head->next->next->next->next = new Node(2);
    head->next->next->next->next->next = new Node(0);
    head->next->next->next->next->next->next = new Node(2);
    head->next->next->next->next->next->next->next = new Node(2);
    
    head = segregate(head);

    while (head != nullptr) {
        cout << head->data;
        if(head->next != nullptr){
            cout << " -> ";
        }
        head = head->next;
    }
    cout << "\n";
    return 0;
}
Java
import java.util.ArrayList;
import java.util.Collections;

class Node {
    int data;   // payload value
    Node next;  // following node; null at the tail

    // Creates a node that holds new_data and is not yet linked.
    Node(int new_data) {
        this.data = new_data;
        this.next = null;
    }
}

class GfG {

    // Naive approach: copy every value into a list, sort it, then write the
    // sorted values back into the nodes from head to tail. Node links are
    // never changed, so the returned head is the node that was passed in.
    static Node segregate(Node head) {
        // zero or one node: already sorted
        if (head == null || head.next == null) {
            return head;
        }

        // gather the node values
        ArrayList<Integer> values = new ArrayList<>();
        for (Node p = head; p != null; p = p.next) {
            values.add(p.data);
        }

        // ascending natural order (same as Collections.sort)
        values.sort(null);

        // overwrite the nodes with the sorted values, one per node
        Node p = head;
        for (int v : values) {
            p.data = v;
            p = p.next;
        }
        return head;
    }

    public static void main(String[] args) {
        // sample list: 1 -> 2 -> 2 -> 1 -> 2 -> 0 -> 2 -> 2
        int[] sample = {1, 2, 2, 1, 2, 0, 2, 2};
        Node head = new Node(sample[0]);
        Node tail = head;
        for (int i = 1; i < sample.length; i++) {
            tail.next = new Node(sample[i]);
            tail = tail.next;
        }

        head = segregate(head);

        // print the values separated by " -> " on a single line
        StringBuilder out = new StringBuilder();
        for (Node p = head; p != null; p = p.next) {
            out.append(p.data);
            if (p.next != null) {
                out.append(" -> ");
            }
        }
        System.out.println(out);
    }
}
Python
class Node:
    """A singly linked list node storing ``data`` and a ``next`` link."""

    def __init__(self, new_data):
        # payload first, then the (initially empty) link to the successor
        self.data = new_data
        self.next = None

def segregate(head):
    """Sort a linked list of 0s, 1s and 2s by value (naive approach).

    The node values are copied into a Python list, sorted, and written
    back into the same nodes from head to tail. Links are not changed,
    so the returned head is the node that was passed in.
    """
    # fewer than two nodes: nothing to do
    if not (head and head.next):
        return head

    values = []
    node = head
    while node:
        values.append(node.data)
        node = node.next

    # write the values back in ascending order, one per node
    node = head
    for value in sorted(values):
        node.data = value
        node = node.next

    return head

if __name__ == "__main__":
    # build 1 -> 2 -> 2 -> 1 -> 2 -> 0 -> 2 -> 2 by prepending in reverse order
    head = None
    for value in reversed([1, 2, 2, 1, 2, 0, 2, 2]):
        node = Node(value)
        node.next = head
        head = node

    head = segregate(head)

    # print the values joined by " -> " followed by a newline
    parts = []
    while head:
        parts.append(str(head.data))
        head = head.next
    print(" -> ".join(parts))
C#
using System;
using System.Collections.Generic;

class Node {
    public int data;   // value held by the node
    public Node next;  // successor; null marks the end of the list

    // Creates an unlinked node carrying new_data (next stays null by default).
    public Node(int new_data) => data = new_data;
}

class GfG {

    // Naive approach: collect the values, sort them, and write them back
    // into the existing nodes in order. Links are untouched, so the same
    // head node is returned.
    public static Node segregate(Node head) {
        // empty or single-node lists are already sorted
        if (head == null || head.next == null) {
            return head;
        }

        var values = new List<int>();
        for (var p = head; p != null; p = p.next) {
            values.Add(p.data);
        }

        values.Sort();

        // one sorted value per node, from head to tail
        var node = head;
        foreach (var v in values) {
            node.data = v;
            node = node.next;
        }
        return head;
    }

    public static void Main() {
        // sample list: 1 -> 2 -> 2 -> 1 -> 2 -> 0 -> 2 -> 2
        int[] sample = { 1, 2, 2, 1, 2, 0, 2, 2 };
        Node head = new Node(sample[0]);
        Node tail = head;
        for (int i = 1; i < sample.Length; i++) {
            tail.next = new Node(sample[i]);
            tail = tail.next;
        }

        head = segregate(head);

        var parts = new List<string>();
        for (var p = head; p != null; p = p.next) {
            parts.Add(p.data.ToString());
        }
        Console.WriteLine(string.Join(" -> ", parts));
    }
}
JavaScript
/** A singly linked list node: `data` holds the value, `next` the successor (or null). */
class Node {
    data;
    next = null;

    constructor(new_data) {
        this.data = new_data;
    }
}

/**
 * Naive approach: copy the node values into an array, sort it, and write the
 * sorted values back into the same nodes from head to tail. Node links are
 * never changed, so the returned head is the node that was passed in.
 * (The previous version implemented counting, which does not match this
 * section, and produced NaN counts for values outside 0..2.)
 */
function segregate(head) {
    if (!head || !head.next)
        return head;

    // convert linked list to array
    const arr = [];
    for (let node = head; node; node = node.next)
        arr.push(node.data);

    // sort numerically: the default Array#sort compares values as strings
    arr.sort((a, b) => a - b);

    // reassign sorted values back to the linked list, one per node
    let curr = head;
    for (const val of arr) {
        curr.data = val;
        curr = curr.next;
    }

    return head;
}

// Driver Code: build 1 -> 2 -> 2 -> 1 -> 2 -> 0 -> 2 -> 2 by prepending in reverse
let head = null;
for (const value of [1, 2, 2, 1, 2, 0, 2, 2].reverse()) {
    const node = new Node(value);
    node.next = head;
    head = node;
}

head = segregate(head);

// print the values joined by " -> ", then a newline
const parts = [];
for (let temp = head; temp; temp = temp.next) {
    parts.push(temp.data.toString());
}
process.stdout.write(parts.join(" -> "));
console.log();

Output
0 -> 1 -> 1 -> 2 -> 2 -> 2 -> 2 -> 2

[Expected Approach - 1] Using Count of 0s, 1s and 2s - O(n) Time and O(1) Space

The idea is to traverse the linked list once and count the number of occurrences of 0s, 1s, and 2s. Once the counts are known, the linked list is traversed again and the node values are overwritten based on the counts: the first cntZero nodes are set to 0, the next cntOne nodes are set to 1, and the remaining nodes are set to 2.

C++
#include <iostream>
using namespace std;
class Node {
public:
    int data;     // value stored in this node
    Node* next;   // successor node, or nullptr at the end of the list

    // Builds a detached node holding new_data.
    Node(int new_data) : data(new_data), next(nullptr) {}
};

// Counting approach: tally how many nodes hold 0, 1 and anything else
// (treated as 2), then overwrite the node values in that order. Node links
// are not modified, so the original head is returned.
Node* segregate(Node* head) {
    // Lists with fewer than two nodes are already sorted
    if (head == nullptr || head->next == nullptr)
        return head;

    // cnt[v] = number of nodes that will receive value v
    int cnt[3] = {0, 0, 0};
    for (Node* node = head; node != nullptr; node = node->next) {
        int slot = (node->data == 0) ? 0 : (node->data == 1) ? 1 : 2;
        cnt[slot]++;
    }

    // Overwrite: cnt[0] zeros, then cnt[1] ones, then cnt[2] twos
    Node* node = head;
    for (int v = 0; v < 3; v++) {
        for (int k = 0; k < cnt[v]; k++) {
            node->data = v;
            node = node->next;
        }
    }

    return head;
}

int main() {
    
    Node* head = new Node(1);
    head->next = new Node(2);
    head->next->next = new Node(2);
    head->next->next->next = new Node(1);
    head->next->next->next->next = new Node(2);
    head->next->next->next->next->next = new Node(0);
    head->next->next->next->next->next->next = new Node(2);
    head->next->next->next->next->next->next->next = new Node(2);

    head = segregate(head);

    while (head != nullptr) {
        cout << head->data;
        if(head->next != nullptr){
            cout << " -> ";
        }
        head = head->next;
    }
    cout << "\n";
    return 0;
}
Java
class Node {
    int data;   // payload value
    Node next;  // following node; null at the tail

    // Creates a node that holds new_data and is not yet linked.
    Node(int new_data) {
        this.data = new_data;
        this.next = null;
    }
}

class GfG {

    // Counting approach: tally how many nodes hold 0, 1 and anything else
    // (tallied as 2), then overwrite the node values in that order.
    // Links are untouched; the original head is returned.
    static Node segregate(Node head) {
        if (head == null || head.next == null) {
            return head;
        }

        // tally[v] = number of nodes that will receive value v
        int[] tally = new int[3];
        for (Node p = head; p != null; p = p.next) {
            if (p.data == 0) {
                tally[0]++;
            } else if (p.data == 1) {
                tally[1]++;
            } else {
                tally[2]++;
            }
        }

        // rewrite: tally[0] zeros, then tally[1] ones, then tally[2] twos
        Node p = head;
        for (int v = 0; v < 3; v++) {
            for (int k = 0; k < tally[v]; k++) {
                p.data = v;
                p = p.next;
            }
        }
        return head;
    }

    public static void main(String[] args) {
        // sample list: 1 -> 2 -> 2 -> 1 -> 2 -> 0 -> 2 -> 2
        int[] sample = {1, 2, 2, 1, 2, 0, 2, 2};
        Node head = new Node(sample[0]);
        Node tail = head;
        for (int i = 1; i < sample.length; i++) {
            tail.next = new Node(sample[i]);
            tail = tail.next;
        }

        head = segregate(head);

        // print the values separated by " -> " on a single line
        StringBuilder out = new StringBuilder();
        for (Node p = head; p != null; p = p.next) {
            out.append(p.data);
            if (p.next != null) {
                out.append(" -> ");
            }
        }
        System.out.println(out);
    }
}
Python
class Node:
    """A singly linked list node storing ``data`` and a ``next`` link."""

    def __init__(self, new_data):
        # payload first, then the (initially empty) link to the successor
        self.data = new_data
        self.next = None

def segregate(head):
    """Sort a linked list of 0s, 1s and 2s by counting values.

    The first pass tallies nodes holding 0, 1 and anything else (which
    is tallied as 2). The second pass overwrites the node values with
    that many 0s, then 1s, then 2s. Links are not modified, so the
    returned head is the node that was passed in.
    """
    if not (head and head.next):
        return head

    # counts[v] = how many nodes will receive value v
    counts = [0, 0, 0]
    node = head
    while node:
        slot = 0 if node.data == 0 else 1 if node.data == 1 else 2
        counts[slot] += 1
        node = node.next

    # rewrite the values in order, walking the list once more
    node = head
    for value, how_many in enumerate(counts):
        for _ in range(how_many):
            node.data = value
            node = node.next

    return head

if __name__ == "__main__":
    # build 1 -> 2 -> 2 -> 1 -> 2 -> 0 -> 2 -> 2 by prepending in reverse order
    head = None
    for value in reversed([1, 2, 2, 1, 2, 0, 2, 2]):
        node = Node(value)
        node.next = head
        head = node

    head = segregate(head)

    # print the values joined by " -> " followed by a newline
    parts = []
    while head:
        parts.append(str(head.data))
        head = head.next
    print(" -> ".join(parts))
C#
using System;
using System.Collections.Generic;

class Node {
    public int data;   // value held by the node
    public Node next;  // successor; null marks the end of the list

    // Creates an unlinked node carrying new_data (next stays null by default).
    public Node(int new_data) => data = new_data;
}

class GfG {

    // Counting approach: tally the nodes holding 0, 1 and anything else
    // (tallied as 2), then overwrite the values in that order. Links are
    // untouched, so the original head is returned.
    public static Node segregate(Node head) {
        if (head == null || head.next == null) {
            return head;
        }

        // tally[v] counts the nodes that will receive value v
        var tally = new int[3];
        for (var p = head; p != null; p = p.next) {
            int slot = p.data == 0 ? 0 : p.data == 1 ? 1 : 2;
            tally[slot]++;
        }

        // overwrite in order: tally[0] zeros, tally[1] ones, tally[2] twos
        var node = head;
        for (int v = 0; v < tally.Length; v++) {
            for (int k = 0; k < tally[v]; k++) {
                node.data = v;
                node = node.next;
            }
        }
        return head;
    }

    public static void Main() {
        // sample list: 1 -> 2 -> 2 -> 1 -> 2 -> 0 -> 2 -> 2
        int[] sample = { 1, 2, 2, 1, 2, 0, 2, 2 };
        Node head = new Node(sample[0]);
        Node tail = head;
        for (int i = 1; i < sample.Length; i++) {
            tail.next = new Node(sample[i]);
            tail = tail.next;
        }

        head = segregate(head);

        var parts = new List<string>();
        for (var p = head; p != null; p = p.next) {
            parts.Add(p.data.ToString());
        }
        Console.WriteLine(string.Join(" -> ", parts));
    }
}
JavaScript
/** A singly linked list node: `data` holds the value, `next` the successor (or null). */
class Node {
    data;
    next = null;

    constructor(new_data) {
        this.data = new_data;
    }
}

/**
 * Counting approach (matching the C++/Java/Python/C# versions of this
 * section; the previous JavaScript code implemented link-changing instead).
 * First pass: count nodes holding 0, 1 and anything else (counted as 2).
 * Second pass: overwrite the node values with that many 0s, 1s, then 2s.
 * Node links are not modified, so the returned head is the input head.
 */
function segregate(head) {
    if (!head || !head.next)
        return head;

    // count the number of 0s, 1s, and 2s
    let cntZero = 0, cntOne = 0, cntTwo = 0;
    for (let node = head; node; node = node.next) {
        if (node.data === 0) cntZero++;
        else if (node.data === 1) cntOne++;
        else cntTwo++;
    }

    // rebuild the list values: first all 0s, then all 1s, then all 2s
    let curr = head;
    for (; cntZero > 0; cntZero--) {
        curr.data = 0;
        curr = curr.next;
    }
    for (; cntOne > 0; cntOne--) {
        curr.data = 1;
        curr = curr.next;
    }
    for (; cntTwo > 0; cntTwo--) {
        curr.data = 2;
        curr = curr.next;
    }

    return head;
}

// Driver Code: build 1 -> 2 -> 2 -> 1 -> 2 -> 0 -> 2 -> 2 by prepending in reverse
let head = null;
for (const value of [1, 2, 2, 1, 2, 0, 2, 2].reverse()) {
    const node = new Node(value);
    node.next = head;
    head = node;
}

head = segregate(head);

// print the values joined by " -> ", then a newline
const parts = [];
for (let temp = head; temp; temp = temp.next) {
    parts.push(temp.data.toString());
}
process.stdout.write(parts.join(" -> "));
console.log();

Output
0 -> 1 -> 1 -> 2 -> 2 -> 2 -> 2 -> 2

[Expected Approach - 2] By Changing Links of Nodes - O(n) Time and O(1) Space

The idea is to maintain 3 pointers named zero, one and two to point to current ending nodes of linked lists containing 0, 1, and 2 respectively. For every traversed node, we attach it to the end of its corresponding list.

  • If the current node's value is 0, append it after pointer zero and move pointer zero to current node.
  • If the current node's value is 1, append it after pointer one and move pointer one to current node.
  • If the current node's value is 2, append it after pointer two and move pointer two to current node.

Finally, we link all three lists. To avoid many null checks, we use three dummy nodes zeroD, oneD and twoD that act as the headers of the three lists.

C++
#include <iostream> 
using namespace std;

class Node {
public:
    int data;     // value stored in this node
    Node* next;   // successor node, or nullptr at the end of the list

    // Builds a detached node holding new_data.
    Node(int new_data) : data(new_data), next(nullptr) {}
};
 
// Link-changing approach: append every node to one of three sublists
// (0s, 1s, and everything else as 2s), then join them 0s -> 1s -> 2s.
// Node values are not modified and the order within each group is kept.
// Returns the head of the rearranged list, which may differ from the input.
Node* segregate(Node* head) {
    if (!head || !(head->next))
        return head;

    // Dummy headers for the three sublists. They live on the stack, so no
    // new/delete is needed and nothing leaks if this function exits early.
    Node zeroD(0), oneD(0), twoD(0);

    // Current tails of the three sublists
    Node *zero = &zeroD, *one = &oneD, *two = &twoD;

    // traverse list
    for (Node* curr = head; curr; curr = curr->next) {
        if (curr->data == 0) {
            // append it to the 0s sublist and update zero
            zero->next = curr;
            zero = zero->next;
        }
        else if (curr->data == 1) {
            // append it to the 1s sublist and update one
            one->next = curr;
            one = one->next;
        }
        else {
            // append it to the 2s sublist and update two
            two->next = curr;
            two = two->next;
        }
    }

    // combine the three lists; skip the 1s sublist if it is empty
    zero->next = oneD.next ? oneD.next : twoD.next;
    one->next = twoD.next;
    two->next = nullptr;   // nullptr, consistent with the rest of the file

    // updated head
    return zeroD.next;
}

int main() {
    
    Node* head = new Node(1);
    head->next = new Node(2);
    head->next->next = new Node(2);
    head->next->next->next = new Node(1);
    head->next->next->next->next = new Node(2);
    head->next->next->next->next->next = new Node(0);
    head->next->next->next->next->next->next = new Node(2);
    head->next->next->next->next->next->next->next = new Node(2);

    head = segregate(head);

    while (head != nullptr) {
        cout << head->data;
        if(head->next != nullptr){
            cout << " -> ";
        }
        head = head->next;
    }
  	cout << "\n";
    return 0;
}
Java
// A singly linked list node.
class Node {
    int data;   // payload value
    Node next;  // following node; null at the tail

    // Creates a node that holds new_data and is not yet linked.
    Node(int new_data) {
        this.data = new_data;
        this.next = null;
    }
}
 
class GfG {

    // Relinking approach: every node is appended to one of three sublists
    // (0s, 1s, and anything else as 2s) headed by dummy nodes, and the
    // sublists are then joined. Node values are not modified; the returned
    // head may differ from the input head.
    static Node segregate(Node head) {
        if (head == null || head.next == null) {
            return head;
        }

        // dummy[k] heads the sublist for value k; tail[k] is that sublist's last node
        Node[] dummy = { new Node(0), new Node(0), new Node(0) };
        Node[] tail = { dummy[0], dummy[1], dummy[2] };

        for (Node curr = head; curr != null; curr = curr.next) {
            int k = (curr.data == 0) ? 0 : (curr.data == 1) ? 1 : 2;
            tail[k].next = curr;
            tail[k] = curr;
        }

        // join 0s -> 1s -> 2s, linking 0s straight to 2s if there are no 1s
        tail[0].next = (dummy[1].next != null) ? dummy[1].next : dummy[2].next;
        tail[1].next = dummy[2].next;
        tail[2].next = null;

        return dummy[0].next;
    }

    public static void main(String[] args) {
        // sample list: 1 -> 2 -> 2 -> 1 -> 2 -> 0 -> 2 -> 2
        int[] sample = {1, 2, 2, 1, 2, 0, 2, 2};
        Node head = new Node(sample[0]);
        Node last = head;
        for (int i = 1; i < sample.length; i++) {
            last.next = new Node(sample[i]);
            last = last.next;
        }

        head = segregate(head);

        // print the values separated by " -> " on a single line
        StringBuilder out = new StringBuilder();
        for (Node p = head; p != null; p = p.next) {
            out.append(p.data);
            if (p.next != null) {
                out.append(" -> ");
            }
        }
        System.out.println(out);
    }

}
Python
class Node:
    """A singly linked list node storing ``data`` and a ``next`` link."""

    def __init__(self, new_data):
        # payload first, then the (initially empty) link to the successor
        self.data = new_data
        self.next = None

def segregate(head):
    """Relink the nodes of a 0/1/2 list into 0s, then 1s, then 2s.

    Each node is appended to one of three sublists, each headed by a
    dummy node, according to its value. Values other than 0 and 1 go to
    the 2s sublist. The sublists are then concatenated. Node values are
    never modified, and the relative order inside each group is kept.
    Returns the head of the rearranged list.
    """
    if not (head and head.next):
        return head

    heads = [Node(0), Node(0), Node(0)]  # dummy header per value
    tails = list(heads)                   # last node of each sublist

    node = head
    while node:
        slot = 0 if node.data == 0 else 1 if node.data == 1 else 2
        tails[slot].next = node
        tails[slot] = node
        node = node.next

    # join: 0s -> (1s, or 2s when there are no 1s); 1s -> 2s; terminate
    tails[0].next = heads[1].next if heads[1].next else heads[2].next
    tails[1].next = heads[2].next
    tails[2].next = None

    return heads[0].next

if __name__ == "__main__":
    # build 1 -> 2 -> 2 -> 1 -> 2 -> 0 -> 2 -> 2 by prepending in reverse order
    head = None
    for value in reversed([1, 2, 2, 1, 2, 0, 2, 2]):
        node = Node(value)
        node.next = head
        head = node

    head = segregate(head)

    # print the values joined by " -> " followed by a newline
    parts = []
    while head is not None:
        parts.append(str(head.data))
        head = head.next
    print(" -> ".join(parts))
    
C#
using System;

// A singly linked list node.
public class Node {

    public int Data;   // value held by the node
    public Node Next;  // successor node; null at the end of the list

    // Creates an unlinked node holding newData (Next defaults to null).
    public Node(int newData) => Data = newData;
}
 
public class GfG {

    // Relinking approach: append each node to one of three dummy-headed
    // sublists (0s, 1s, and anything else as 2s), then chain them
    // 0s -> 1s -> 2s. Node values are not modified; the returned head
    // may differ from the input head.
    public static Node segregate(Node head) {
        if (head == null || head.Next == null) {
            return head;
        }

        // dummy[k] heads the sublist for value k; tail[k] is its last node
        Node[] dummy = { new Node(0), new Node(0), new Node(0) };
        Node[] tail = { dummy[0], dummy[1], dummy[2] };

        for (Node curr = head; curr != null; curr = curr.Next) {
            int k = curr.Data == 0 ? 0 : curr.Data == 1 ? 1 : 2;
            tail[k].Next = curr;
            tail[k] = curr;
        }

        // an empty 1s sublist is skipped by linking 0s straight to 2s
        tail[0].Next = dummy[1].Next ?? dummy[2].Next;
        tail[1].Next = dummy[2].Next;
        tail[2].Next = null;

        return dummy[0].Next;
    }

    public static void Main() {
        // sample list: 1 -> 2 -> 2 -> 1 -> 2 -> 0 -> 2 -> 2
        int[] sample = { 1, 2, 2, 1, 2, 0, 2, 2 };
        Node head = new Node(sample[0]);
        Node last = head;
        for (int i = 1; i < sample.Length; i++) {
            last.Next = new Node(sample[i]);
            last = last.Next;
        }

        head = segregate(head);

        string line = "";
        for (Node p = head; p != null; p = p.Next) {
            line += p.Data;
            if (p.Next != null) {
                line += " -> ";
            }
        }
        Console.WriteLine(line);
    }
}
JavaScript
/** A singly linked list node: `data` holds the value, `next` the successor (or null). */
class Node {
    data;
    next = null;

    constructor(newData) {
        this.data = newData;
    }
}

/**
 * Relinking approach: append every node to one of three dummy-headed
 * sublists (0s, 1s, and anything else as 2s), then join them
 * 0s -> 1s -> 2s. Node values are not modified and the order within each
 * group is preserved. Returns the head of the rearranged list.
 */
function segregate(head) {
    if (!head || !head.next)
        return head;

    // dummy[k] heads the sublist for value k, tail[k] is its last node
    const dummy = [new Node(0), new Node(0), new Node(0)];
    const tail = [...dummy];

    for (let curr = head; curr; curr = curr.next) {
        const k = curr.data === 0 ? 0 : curr.data === 1 ? 1 : 2;
        tail[k].next = curr;
        tail[k] = curr;
    }

    // link 0s to 1s (or straight to 2s when there are no 1s), 1s to 2s
    tail[0].next = dummy[1].next ? dummy[1].next : dummy[2].next;
    tail[1].next = dummy[2].next;
    tail[2].next = null;

    return dummy[0].next;
}

// Driver code
let head = new Node(1);
head.next = new Node(2);
head.next.next = new Node(2);
head.next.next.next = new Node(1);
head.next.next.next.next = new Node(2);
head.next.next.next.next.next = new Node(0);
head.next.next.next.next.next.next = new Node(2);
head.next.next.next.next.next.next.next = new Node(2);

// Declare temp explicitly: assigning to an undeclared name creates an
// implicit global and throws a ReferenceError in strict mode / ES modules.
let temp = segregate(head);
while (temp) {
    process.stdout.write(temp.data.toString());
    if (temp.next != null) {
        process.stdout.write(" -> ");
    }
    temp = temp.next;
}
console.log();

Output
0 -> 1 -> 1 -> 2 -> 2 -> 2 -> 2 -> 2

The idea is to split the linked list into three separate sublists for 0s, 1s, and 2s, a linked-list analogue of the three-way partition used in the Dutch National Flag problem. We maintain three dummy nodes and corresponding tail pointers to build each sublist during a single traversal. Once the segregation is done, we link these sublists in order: 0s -> 1s -> 2s. This avoids modifying node values and runs in O(n) time with O(1) extra space.

Step-by-step implementation:

  • Create three dummy nodes to act as the start of separate lists for 0s, 1s, and 2s.
  • Initialize three tail pointers (zero, one, two) that point to the end of each of these sublists.
  • Traverse the original list and based on node value, append it to the respective sublist using tail pointers.
  • After appending a node, move the corresponding tail pointer forward to keep track of the last node.
  • Once traversal is complete, link zero list to one list and then link one list to two list carefully.
  • Ensure the last node of two list points to NULL to terminate the final merged list correctly.
  • Return the head of the combined list which starts from the next of zero dummy node.
C++
#include <iostream>
using namespace std;

class Node {
public:
    int data;     // value stored in this node
    Node* next;   // successor node, or nullptr at the end of the list

    // Builds a detached node holding new_data.
    Node(int new_data) : data(new_data), next(nullptr) {}
};

// Relinking approach: distribute the nodes into three dummy-headed sublists
// indexed by value (0, 1, and anything else as 2), then stitch them together
// 0s -> 1s -> 2s. Node values are unchanged and the order within each group
// is preserved. Returns the head of the rearranged list.
Node* segregate(Node* head) {
    // Zero or one node: nothing to rearrange
    if (head == nullptr || head->next == nullptr) {
        return head;
    }

    // dummy[k] heads the sublist for value k; tail[k] is its last node
    Node* dummy[3] = {new Node(-1), new Node(-1), new Node(-1)};
    Node* tail[3] = {dummy[0], dummy[1], dummy[2]};

    Node* curr = head;
    while (curr != nullptr) {
        int k = (curr->data == 0) ? 0 : (curr->data == 1 ? 1 : 2);
        tail[k]->next = curr;
        tail[k] = curr;
        curr = curr->next;
    }

    // Stitch: 0s -> (1s, or 2s if there are no 1s); 1s -> 2s; 2s end the list
    tail[0]->next = (dummy[1]->next != nullptr) ? dummy[1]->next : dummy[2]->next;
    tail[1]->next = dummy[2]->next;
    tail[2]->next = nullptr;

    Node* result = dummy[0]->next;
    for (Node* d : dummy) {
        delete d;
    }
    return result;
}

int main() {
    
    Node* head = new Node(1);
    head->next = new Node(2);
    head->next->next = new Node(2);
    head->next->next->next = new Node(1);
    head->next->next->next->next = new Node(2);
    head->next->next->next->next->next = new Node(0);
    head->next->next->next->next->next->next = new Node(2);
    head->next->next->next->next->next->next->next = new Node(2);

    head = segregate(head);

    while (head != nullptr) {
        cout << head->data;
        if(head->next != nullptr){
            cout << " -> ";
        }
        head = head->next;
    }
    cout << "\n";
    return 0;
}
Java
class Node {
    int data;   // value carried by the node
    Node next;  // link to the following node; null at the tail

    // Creates an unlinked node storing new_data.
    Node(int new_data) {
        this.data = new_data;
        this.next = null;
    }
}

class GfG {

    // Splits the list into 0s, 1s and 2s sublists (each behind a dummy node)
    // in one pass, then concatenates them. Values other than 0 and 1 join the
    // 2s sublist. Node values are not changed; the returned head may differ
    // from the input head.
    static Node segregate(Node head) {
        if (head == null || head.next == null) {
            return head;
        }

        Node zeroHead = new Node(-1), oneHead = new Node(-1), twoHead = new Node(-1);
        Node zeroTail = zeroHead, oneTail = oneHead, twoTail = twoHead;

        for (Node node = head; node != null; node = node.next) {
            switch (node.data) {
                case 0:
                    zeroTail.next = node;
                    zeroTail = node;
                    break;
                case 1:
                    oneTail.next = node;
                    oneTail = node;
                    break;
                default: // 2, and any unexpected value
                    twoTail.next = node;
                    twoTail = node;
                    break;
            }
        }

        // a missing 1s sublist is skipped by linking 0s straight to 2s
        zeroTail.next = (oneHead.next == null) ? twoHead.next : oneHead.next;
        oneTail.next = twoHead.next;
        twoTail.next = null;

        return zeroHead.next;
    }

    public static void main(String[] args) {
        // sample list: 1 -> 2 -> 2 -> 1 -> 2 -> 0 -> 2 -> 2
        int[] sample = {1, 2, 2, 1, 2, 0, 2, 2};
        Node head = new Node(sample[0]);
        Node last = head;
        for (int i = 1; i < sample.length; i++) {
            last.next = new Node(sample[i]);
            last = last.next;
        }

        head = segregate(head);

        // print the values separated by " -> " on a single line
        StringBuilder out = new StringBuilder();
        for (Node p = head; p != null; p = p.next) {
            out.append(p.data);
            if (p.next != null) {
                out.append(" -> ");
            }
        }
        System.out.println(out);
    }
}
Python
class Node:
    """Node of a singly linked list storing a single value."""

    def __init__(self, new_data):
        # value carried by the node, and the link to the following node (None at the tail)
        self.data, self.next = new_data, None

def segregate(head):
    """Rearrange a linked list of 0s, 1s and 2s by relinking its nodes.

    After the call every 0 precedes every 1, which precedes every 2. Nodes
    within each group keep their original relative order. Any value that is
    neither 0 nor 1 is grouped with the 2s.

    Args:
        head: first node of the list, or None.

    Returns:
        The first node of the rearranged list.
    """
    # Lists with fewer than two nodes are already arranged.
    if not (head and head.next):
        return head

    # One sentinel head per group; tails[g] is the last node appended to group g.
    sentinels = [Node(-1), Node(-1), Node(-1)]
    tails = sentinels[:]

    node = head
    while node:
        group = 0 if node.data == 0 else (1 if node.data == 1 else 2)
        tails[group].next = node
        tails[group] = node
        node = node.next

    zeros_tail, ones_tail, twos_tail = tails
    ones_first, twos_first = sentinels[1].next, sentinels[2].next

    # Chain the groups; when there are no 1s the 0s link straight to the 2s.
    zeros_tail.next = ones_first or twos_first
    ones_tail.next = twos_first
    twos_tail.next = None

    # With no 0s, the zeros sentinel now points at the 1s or 2s group.
    return sentinels[0].next

if __name__ == "__main__":

    # Build the sample list 1 -> 2 -> 2 -> 1 -> 2 -> 0 -> 2 -> 2
    head = Node(1)
    head.next = Node(2)
    head.next.next = Node(2)
    head.next.next.next = Node(1)
    head.next.next.next.next = Node(2)
    head.next.next.next.next.next = Node(0)
    head.next.next.next.next.next.next = Node(2)
    head.next.next.next.next.next.next.next = Node(2)

    head = segregate(head)

    # Print the rearranged list as "a -> b -> ..." on a single line.
    while head:
        print(head.data, end="")
        # Compare to the None singleton with `is not` (PEP 8), not `!=`.
        if head.next is not None:
            print(" -> ", end="")
        head = head.next
    print()
C#
using System;

/// <summary>
/// A node of a singly linked list holding one integer value.
/// </summary>
class Node {
    /// <summary>Value stored in this node.</summary>
    public int data;

    /// <summary>Following node in the list; null marks the end of the list.</summary>
    public Node next = null;

    /// <summary>Creates an unlinked node holding <paramref name="new_data"/>.</summary>
    /// <param name="new_data">Value to store.</param>
    public Node(int new_data) {
        this.data = new_data;
    }
}

class GfG {

    /// <summary>
    /// Rearranges a linked list of 0s, 1s and 2s so that all 0s come first,
    /// then all 1s, then all 2s. Nodes keep their original relative order
    /// within each group. The existing nodes are relinked; no node data is copied.
    /// Any value that is neither 0 nor 1 is placed in the 2s group.
    /// </summary>
    /// <param name="head">First node of the list; may be null.</param>
    /// <returns>The first node of the rearranged list.</returns>
    public static Node segregate(Node head) {
        // An empty or single-node list is already arranged
        if (head == null || head.next == null) {
            return head;
        }

        // Dummy nodes for three separate lists
        Node zeroD = new Node(-1);
        Node oneD = new Node(-1);
        Node twoD = new Node(-1);

        // Tails for the three lists
        Node zero = zeroD;
        Node one = oneD;
        Node two = twoD;

        // Traverse the original list, appending each node to the tail of its group
        Node curr = head;
        while (curr != null) {
            if (curr.data == 0) {
                zero.next = curr;
                zero = zero.next;
            } else if (curr.data == 1) {
                one.next = curr;
                one = one.next;
            } else {
                two.next = curr;
                two = two.next;
            }
            curr = curr.next;
        }

        // Connect the three lists together. With no 1s, the 0s link straight to the 2s.
        // oneD.next must be read before one.next is overwritten, because one may still be oneD.
        zero.next = (oneD.next != null) ? oneD.next : twoD.next;
        one.next = twoD.next;
        two.next = null;

        // With no 0s, zeroD.next was just pointed at the 1s or 2s list
        return zeroD.next;
    }

    /// <summary>
    /// Builds the sample list 1 -> 2 -> 2 -> 1 -> 2 -> 0 -> 2 -> 2, rearranges it
    /// and prints the result as "a -> b -> ...".
    /// </summary>
    public static void Main() {

        Node head = new Node(1);
        head.next = new Node(2);
        head.next.next = new Node(2);
        head.next.next.next = new Node(1);
        head.next.next.next.next = new Node(2);
        head.next.next.next.next.next = new Node(0);
        head.next.next.next.next.next.next = new Node(2);
        head.next.next.next.next.next.next.next = new Node(2);

        head = segregate(head);

        while (head != null) {
            Console.Write(head.data);
            if (head.next != null) {
                Console.Write(" -> ");
            }
            head = head.next;
        }
        Console.WriteLine();
    }
}
JavaScript
/**
 * Node of a singly linked list.
 * @constructor
 * @param {number} new_data - value stored in the node
 */
function Node(new_data) {
    // creates own properties in the order data, next; next stays null until linked
    Object.assign(this, { data: new_data, next: null });
}

/**
 * Rearranges a linked list of 0s, 1s and 2s by relinking its nodes so that
 * all 0s come first, then 1s, then 2s. Nodes keep their original relative
 * order within each group. Anything that is not strictly 0 or 1 joins the 2s.
 * @param {Node|null} head - first node of the list
 * @returns {Node|null} first node of the rearranged list
 */
function segregate(head) {
    // empty or single-node lists are already arranged
    if (!(head && head.next)) {
        return head;
    }

    // sentinel head per group, and the last node appended to each group so far
    const heads = [new Node(-1), new Node(-1), new Node(-1)];
    const tails = heads.slice();

    for (let node = head; node; node = node.next) {
        const group = node.data === 0 ? 0 : node.data === 1 ? 1 : 2;
        tails[group].next = node;
        tails[group] = node;
    }

    const [zeroTail, oneTail, twoTail] = tails;

    // stitch groups together; with no 1s the 0s link straight to the 2s
    zeroTail.next = heads[1].next || heads[2].next;
    oneTail.next = heads[2].next;
    twoTail.next = null;

    return heads[0].next;
}

// Driver code: build 1 -> 2 -> 2 -> 1 -> 2 -> 0 -> 2 -> 2, rearrange it, print it
let head = new Node(1);
let tail = head;
for (const value of [2, 2, 1, 2, 0, 2, 2]) {
    tail.next = new Node(value);
    tail = tail.next;
}

head = segregate(head);

// print the list as "a -> b -> ..." followed by a newline
for (let node = head; node; node = node.next) {
    process.stdout.write(node.data.toString());
    if (node.next != null) {
        process.stdout.write(" -> ");
    }
}
console.log();

Output
0 -> 1 -> 1 -> 2 -> 2 -> 2 -> 2 -> 2
Comment