Sunday, February 26, 2023

ChatGPT on "Riffing on triplet sum to a target"

(Best viewed on a screen larger than a phone's.)

 

Last year I blogged about taking a problem and solving it in multiple ways, just for the fun of it.

Triplet sum is a problem solved by many. I had my own description of it, so I decided to investigate it using ChatGPT and see what it said.

In the following, rather than using screen grabs of the ChatGPT session, I copied my text inputs into the strings me1 through me5 and ChatGPT's responses into c1 through c6. I use Python comments for my views on its responses.

Familiarise yourself with my earlier blog post to better understand the following interaction:

# -*- coding: utf-8 -*-
 
"""
triplet_sum_chatgpt.py

Investigation of ChatGPT and its solution to:
    http://paddy3118.blogspot.com/2022/10/riffing-on-triplet-sum-to-target.html

Created on Sun Feb 26 10:26:41 2023

@author: paddy
"""

#%% BEGIN Chatgpt interaction

me1 = """
Given a list of integers and a target integer, return all "normalised" triplets
of numbers from that list that sum to the target value.
Normalised triplets are a tuple of three values in sorted order,
i.e. (t0, t1, t2) where t0 <= t1 <= t2.
"""

c1 = """
To solve this problem, we can start by sorting the given list of integers. Then,
we can use a nested loop to iterate over all possible pairs of numbers from the
list, and for each pair, we can use binary search to find the third number that
would make the sum equal to the target integer. We can then check if the
triplet is in sorted order and add it to the result if it is.

Here's the Python code that implements this approach:
"""

def find_triplets(arr, target):
    arr.sort()
    n = len(arr)
    res = []
    for i in range(n - 2):
        # skip duplicates
        if i > 0 and arr[i] == arr[i-1]:
            continue
        j = i + 1
        k = n - 1
        while j < k:
            s = arr[i] + arr[j] + arr[k]
            if s == target:
                # check if triplet is in sorted order
                if j == i + 1 or arr[j] != arr[j-1]:
                    res.append((arr[i], arr[j], arr[k]))
                j += 1
                k -= 1
            elif s < target:
                j += 1
            else:
                k -= 1
    return res

c2 = """
This code has a time complexity of O(n^2), where n is the length of the input
list. The sorting step takes O(n log n) time, and the nested loop takes O(n^2)
time. The binary search inside the loop takes O(log n) time, so the overall
time complexity is O(n^2 log n).

Note that this code assumes that there are no duplicates in the input list.
If there are duplicates, we need to skip them in the loop to avoid generating
duplicate triplets. The code handles this by checking if the current number is
equal to the previous number, and skipping it if it is.
"""
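# c1 describes pairs plus a binary search for the third number, but its code
# is a two-pointer scan: for each i, j only moves right and k only moves left,
# so each i costs O(n) and the search is O(n^2) overall. The O(n^2 log n) that
# c2 finishes on fits the approach c1 *describes*, sketched below for
# comparison using the standard bisect module. The function name is
# hypothetical, and like find_triplets it reports each value-triplet once.
from bisect import bisect_left


def triplets_by_binary_search(numbers, target):
    """Hypothetical sketch of c1's prose: fix a pair, binary-search the third."""
    arr = sorted(numbers)  # sorted copy; the caller's list is left alone
    n = len(arr)
    found = set()
    for i in range(n - 2):
        for j in range(i + 1, n - 1):
            need = target - arr[i] - arr[j]
            # Search only to the right of j so no element is used twice.
            k = bisect_left(arr, need, j + 1, n)
            if k < n and arr[k] == need:
                found.add((arr[i], arr[j], arr[k]))
    return sorted(found)

# O(n^2) pairs, each with an O(log n) search, gives O(n^2 log n).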

#%% My follow-up

# As in original blog post:
def tests1(funcs, args=None):
    if args is None:
        args = [([2, 0, 1, 3], 4),
                ([0,1,3,2,4,5,6], 11),
                ([0,1,2], 6),
                ([0,1,2], 3),
                ([0,1,2], 1),
                ([0,1], 2),
                ([0,1], 1),
                ([0], 0),
                ([], 99),
                ]
    separator, indent = ('---\n', '  ') if len(funcs) > 1 else ('', '')
    lines = []
    for numbers, target in args:
        line = ''
        for func in funcs:
            line += f"{indent}{func.__name__}({numbers}, {target}) = "
            try:
                answer = func(numbers, target)
                line += str(answer)
            except Exception as ex:
                line += str(ex)
            line += '\n'
        lines.append(line)
    print(separator.join(lines))

tests1([find_triplets])

# # The above prints the correct results of:
# find_triplets([2, 0, 1, 3], 4) = [(0, 1, 3)]
# find_triplets([0, 1, 3, 2, 4, 5, 6], 11) = [(0, 5, 6), (1, 4, 6), (2, 3, 6), (2, 4, 5)]
# find_triplets([0, 1, 2], 6) = []
# find_triplets([0, 1, 2], 3) = [(0, 1, 2)]
# find_triplets([0, 1, 2], 1) = []
# find_triplets([0, 1], 2) = []
# find_triplets([0, 1], 1) = []
# find_triplets([0], 0) = []
# find_triplets([], 99) = []
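# The nine fixed cases above are a small sample. Below is a minimal sketch of a
# randomized cross-check. It assumes a brute-force filter over
# itertools.combinations (the idea behind triplet_sum1) as the reference.
# Inputs use distinct values only, because find_triplets is written to skip
# repeated value-triplets while a combinations filter keeps one per index
# triple. The helper name and its parameters are hypothetical.
import random
from itertools import combinations


def cross_check(func, trials=200, seed=0):
    """Return the first (numbers, target, got, expected) mismatch, or None."""
    rng = random.Random(seed)
    for _ in range(trials):
        numbers = rng.sample(range(-20, 21), rng.randint(0, 10))
        target = rng.randint(-30, 30)
        expected = [t for t in combinations(sorted(numbers), 3)
                    if sum(t) == target]
        got = func(list(numbers), target)  # pass a copy; find_triplets sorts in place
        if sorted(got) != expected:
            return numbers, target, got, expected
    return None

# Usage: cross_check(find_triplets). A None result means no mismatch was found
# in the trials run, which is evidence, not proof, of correctness.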

#%% Time complexities of blog answers

me2 = """
What is the time complexity of: def triplet_sum1(numbers: list[int], target: int) -> list[tuple[int]]:
    "triplet-sum by filtering all triplet combinations"
    return [triplet for triplet in combinations(sorted(numbers), 3)
            if sum(triplet) == target]
"""

c3 = """
The time complexity of the triplet_sum1 function is O(n^3), where n is the
length of the input list.

The function uses the combinations function from the itertools module to
generate all possible triplets of numbers from the input list. The combinations
function generates all possible combinations of k elements from the input list,
where k is the second argument to the function. In this case, k is 3, so
combinations(sorted(numbers), 3) generates all possible triplets of numbers
from the sorted input list. The number of possible triplets of n elements is n
choose 3, which is equal to (n * (n-1) * (n-2)) / 6. Therefore, the
combinations function generates O(n^3) triplets.

The generated triplets are then filtered using a list comprehension, which
checks if the sum of each triplet is equal to the target integer. Checking the
sum of a triplet takes constant time, so the filtering step takes O(n^3) time.

Therefore, the overall time complexity of the triplet_sum1 function is O(n^3).
This is slower than the approach used in the previous answer, which has a time
complexity of O(n^2 log n).
"""

# The answer looks good. triplet_sum1 is written so that it obviously generates
# the correct output.
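# c3's counting step can be checked directly: combinations(..., 3) over n items
# yields C(n, 3) = n*(n-1)*(n-2)/6 triplets, which grows like n**3 / 6, hence
# O(n^3). A small sketch (the helper name is hypothetical) compares a
# brute-force count with math.comb (Python 3.8+) and the closed form.
from itertools import combinations
from math import comb


def check_triplet_count(n):
    """True if the brute-force count matches both closed forms for n items."""
    counted = sum(1 for _ in combinations(range(n), 3))
    return counted == comb(n, 3) == n * (n - 1) * (n - 2) // 6

# For n = 30 all three equal 4060. For the 666 integers of me5, triplet_sum1
# would sum comb(666, 3) == 49013160 triplets.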


me3 = """
what is the time complexity of: from itertools import islice

def triplet_sum3(numbers: list[int], target: int) -> list[tuple[int]]:
    "Triplet-sum with short-circuited combinations and less copying of numbers."
    numbers.sort()
    answers = []

    for i, item in enumerate(numbers):
        isum = item
        if isum <= target:
            for j, jtem in enumerate(islice(numbers, i+1, None), i+1):
                jsum = isum + jtem
                if jsum <= target:
                    for k, ktem in enumerate(islice(numbers, j+1, None), j+1):
                        ksum = jsum + ktem
                        if ksum == target:
                            answers.append((item, jtem, ktem))
                        elif ksum > target:
                            break
                else:
                    break # Skip as jsum > target
        else:
            break # Skip as item > target

    return answers
"""

c4 = """
The time complexity of the triplet_sum3 function is O(n^2), where n is the
length of the input list.

The function uses a nested loop to generate all possible triplets of numbers
from the sorted input list. The outer loop iterates over each item in the input
list, which takes O(n) time. The inner loop uses the islice function from the
itertools module to iterate over the remaining items in the input list after
the current item, which takes O(n) time. The innermost loop does the same for
the remaining items after the second item in the triplet, which again takes
O(n) time.

Therefore, the overall time complexity of the function is O(n^3). However, the
function uses short-circuiting to avoid iterating over unnecessary
combinations. If the sum of the first item and the second item in a triplet is
greater than the target, the function skips the remaining items in the inner
loop and moves on to the next item in the outer loop. Similarly, if the sum of
the first item is greater than the target, the function skips the remaining
items in the outer loop. This reduces the number of iterations of the inner
loops and makes the overall time complexity O(n^2).

Note that the function sorts the input list before generating the triplets.
Sorting takes O(n log n) time, which is dominated by the O(n^2) time complexity
of the nested loops. Therefore, the overall time complexity of the function is
O(n^2).
"""

# The above is wrong: contrary to its second paragraph, triplet_sum3 does not
# generate all triplets and then test them.
# The calculated time complexity looks plausible.
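# One way to probe c4's O(n^2) claim is to count how often the innermost loop
# body runs. The function below is a hypothetical instrumented copy of
# triplet_sum3's loop structure (same sort, same three short-circuits,
# non-negative inputs assumed as in the examples). It returns a step count
# rather than the answers.
from itertools import islice


def triplet_sum3_steps(numbers, target):
    """Count innermost-loop iterations of triplet_sum3's search."""
    numbers = sorted(numbers)  # sorted copy instead of sorting in place
    steps = 0
    for i, item in enumerate(numbers):
        if item > target:
            break  # outer short-circuit
        for j, jtem in enumerate(islice(numbers, i + 1, None), i + 1):
            jsum = item + jtem
            if jsum > target:
                break  # middle short-circuit
            for ktem in islice(numbers, j + 1, None):
                steps += 1
                if jsum + ktem > target:
                    break  # inner short-circuit
    return steps

# triplet_sum3_steps(list(range(666)), 6) == 19. With a target larger than any
# triplet sum nothing is pruned, so triplet_sum3_steps(list(range(20)), 10**9)
# == comb(20, 3) == 1140. The work depends on the data, and the worst case is
# still O(n^3).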

# Let's try it with my final function, which can do more than triplets:

me4 = '''
what is the time complexity of: def triplet_sum4(numbers: list[int], target: int, n_tuple: int=3) \
        -> list[tuple[int]]:
    """
    Generate triplets from numbers that sum to target in sorted order.
    Indexes the sorted numbers without otherwise copying them.

    n_tuple=3 for triplets, 2 for pairs, etc, for any +ve int
    """
    numbers.sort()
    n_len = len(numbers)
    answers = []
    indices = [-1] * n_tuple  # indices for triples; indices[0] must be -1
    incrementing = 0          # index to increment at any time
    accum = [None] * n_tuple  # (partial) sums in accumulator

    while True:
        indices[incrementing] += 1
        if indices[incrementing] >= n_len:
            # maxed out this index so go back one
            incrementing -= 1
            if incrementing < 0:
                break
            continue  # continue incrementing 'outer' index
        if incrementing < n_tuple - 1:
            indices[incrementing+1] = indices[incrementing]
            acc = accum[incrementing] = numbers[indices[incrementing]] + (
                                    0 if incrementing == 0
                                    else accum[incrementing - 1])
            if acc > target:
                # maxed out this index so go back one
                incrementing -= 1
                if incrementing < 0:
                    break
                continue

            incrementing += 1  # next index
        elif incrementing == n_tuple - 1:
            # updated last index
            acc = numbers[indices[incrementing]] + (
                    0 if incrementing == 0
                    else accum[incrementing - 1])
            if acc == target:
                answers.append(tuple(numbers[i] for i in indices))
    return answers
'''

c5 = """
The time complexity of the triplet_sum4 function is O(n^(k-1)), where n is the
length of the input list and k is the size of the tuple being generated.

The function generates tuples by incrementing indices into the sorted input
list. It uses a single loop that iterates until all tuples of the desired size
have been generated. Each iteration increments one of the indices and updates
an accumulator array that keeps track of the sum of the numbers indexed so far.
If the sum of the numbers indexed so far exceeds the target, the function
backtracks to the previous index and continues incrementing from there. If the
last index is incremented and the sum of the numbers indexed equals the target,
the function adds the tuple to the list of answers.

The time complexity of each iteration of the loop is O(n), since each index can
be incremented at most n times before the loop terminates. Therefore, the time
complexity of generating all tuples of size k is O(n^k).

In the case of generating triplets (k=3), the time complexity is O(n^2), which
is lower than the time complexity of the triplet_sum3 function (also O(n^2))
because the triplet_sum4 function generates fewer tuples. For larger values of
k, the time complexity of triplet_sum4 can become very large, and it may be
more efficient to use a different algorithm that avoids generating all possible
tuples.
"""

# Its description of the workings of triplet_sum4 is wrong too!
# Not all possible tuples are generated and then tested; it short-circuits
# tuple generation when it knows the sum will exceed the target.
# I am unsure of its complexity calculation, which gives both O(n^(k-1))
# and O(n^k).

# I like that it compared its answer with a previous one, though only one.
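# To compare c5's description with the iterative index juggling, here is a
# hypothetical recursive restatement of the same search. It is not the
# triplet_sum4 code itself. Indices are chosen in increasing order. A branch is
# abandoned once a partial sum exceeds the target, so non-negative inputs are
# assumed. As in triplet_sum4, the last position is scanned to the end rather
# than cut off. With no pruning it visits every one of the C(n, k) index
# tuples, which for fixed k grows like n**k / k!.


def ntuple_sum_recursive(numbers, target, n_tuple=3):
    """Return n_tuple-tuples of sorted numbers, in index order, summing to target."""
    nums = sorted(numbers)  # sorted copy; triplet_sum4 sorts in place instead
    answers = []
    chosen = []  # values picked so far for the current partial tuple

    def extend(start, partial):
        depth = len(chosen)
        for idx in range(start, len(nums)):
            acc = partial + nums[idx]
            if depth < n_tuple - 1:
                if acc > target:
                    break  # later values are no smaller, so back up a level
                chosen.append(nums[idx])
                extend(idx + 1, acc)
                chosen.pop()
            elif acc == target:
                answers.append(tuple(chosen) + (nums[idx],))

    extend(0, 0)
    return answers

# ntuple_sum_recursive([0, 1, 3, 2, 4, 5, 6], 11) gives the same four triplets
# as find_triplets above. Passing n_tuple=2 searches for pairs instead.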

# %% What's fastest?

me5 = """
for finding triplets from the integers in range zero to 665 inclusive, that sum
to six; which function is the fastest?
"""

c6 = """
To find the fastest function for finding triplets from integers in range zero
to 665 inclusive that sum to six, we need to test the performance of each
function and compare their execution times. Here's an example test code:


import timeit

def test_function(func):
    numbers = list(range(666))
    target = 6
    result = func(numbers, target)
    assert result == [(0, 2, 4), (0, 3, 3), (1, 1, 4), (1, 2, 3)]
    print(f"{func.__name__} took {timeit.timeit(lambda: func(numbers, target), number=1000):.6f} seconds")

test_function(triplet_sum1)
test_function(triplet_sum2)
test_function(triplet_sum3)
test_function(triplet_sum4)


This code creates a list of integers from 0 to 665, sets the target value to 6,
calls each function with the same arguments, and checks if the output is
correct. Then, it measures the execution time of each function with 1000
iterations and prints the result.

Note that the execution times may vary depending on the machine and the version
of Python used.

Based on the results of the test, the fastest function is likely to be
triplet_sum4, followed by triplet_sum3, triplet_sum2, and triplet_sum1,
in that order. However, the performance difference between the functions is
relatively small for this input size, and other factors such as code
readability and maintainability may also be important considerations.
"""

# Whoops! The context established that each integer in the input list is used
# at most once. The result in c6 assumes input integers can be used twice,
# even though each appears only once in the input!

# The functions from the prior context are its own generated `find_triplets`,
# then my `triplet_sum1`, `triplet_sum3`, and `triplet_sum4`. This last answer
# invents a triplet_sum2 and forgets its own find_triplets.
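# Below is a sketch of how c6's test could be made to hold, under stated
# assumptions. Each input integer is used at most once, so for range(666) and
# a target of 6 the only normalised triplets are (0, 1, 5), (0, 2, 4) and
# (1, 2, 3). c6's expected list both reuses integers and omits (0, 1, 5). The
# harness below has a hypothetical name and parameters and takes the expected
# answer as an argument. It checks each function before timing it. It also
# times a fresh copy of the input on each call, because several of the
# functions sort their argument in place.
import timeit


def time_triplet_functions(funcs, numbers, target, expected, repeats=10):
    """Return {function name: seconds for `repeats` calls}, after checking output."""
    timings = {}
    for func in funcs:
        got = func(list(numbers), target)
        if got != expected:
            raise ValueError(f"{func.__name__} returned {got}, expected {expected}")
        timings[func.__name__] = timeit.timeit(
            lambda f=func: f(list(numbers), target), number=repeats)
    return timings

# Usage, passing only functions actually defined in this session:
# time_triplet_functions([find_triplets], range(666), 6,
#                        [(0, 1, 5), (0, 2, 4), (1, 2, 3)])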


Conclusion

ChatGPT gives no feedback on its level of certainty. It is wrong with the same authority as it is right. Some errors can be caught, but others will not be if you are relying on ChatGPT to be more knowledgeable than yourself.