This is Part 2 of a series where I explore techniques for proving correctness of algorithms in Lean, and then teaching it to Claude Sonnet. You may want to start with Part 1 on memoization.

This time, let us do bottom-up dynamic programming. It has a lot in common with memoization, so we will be able to adapt most of our techniques from Part 1. The main difference in implementation is that in memoization the control flow mostly mirrors the recursive implementation, whereas in bottom-up DP the control flow focuses on building up the data structure containing the solutions.

Our goal is again to prove the equivalence of our dynamic programming solution with the recursive solution. We will take the well-known rod-cutting example, but will also aim to make our code reusable / adaptable to other problems.

In bottom-up dynamic programming, we again have a data structure storing solutions to subproblems. As in our memoization exercise, we attach a subtype to the pairs of values stored in the data structure.

import Mathlib.Tactic

def Cell (ftarget: α→β):=
  {c: α × β//ftarget c.fst =c.snd}
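
To see what the subtype buys us, here is a minimal sketch of building one cell by hand. The namespace CellSketch and the function double are hypothetical names used only for illustration, and Cell is restated with explicit binders so the block stands alone. The point is that a pair can only be packaged as a Cell together with a proof that its second component is the target function applied to its first.

```lean
namespace CellSketch

-- Restated with explicit binders so this block is self-contained.
def Cell {α β : Type} (ftarget : α → β) :=
  {c : α × β // ftarget c.fst = c.snd}

-- Hypothetical target function, for illustration only.
def double (k : Nat) : Nat := 2 * k

-- Accepted: `double 3` evaluates to `6`, so the obligation holds by `rfl`.
def goodCell : Cell double := ⟨(3, 6), rfl⟩

-- Rejected if uncommented: there is no proof of `double 3 = 7`.
-- def badCell : Cell double := ⟨(3, 7), rfl⟩

end CellSketch
```

Once a value has type Cell double, any code that reads its second component can rely on the attached proof instead of recomputing the function.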

For simplicity, let us focus on one-dimensional dynamic programming, with subproblems indexed by natural numbers. Then a natural data structure for our purpose would be Arrays.

def DPArray (ftarget:Nat->β) (n:Nat)
:= {arr:Array (Cell ftarget)//
    arr.size=n+1 ∧
    ∀ i: Fin arr.size,
    i = arr[i].val.fst
   }

An array of type DPArray ftarget n will store solutions to subproblems up to and including n.

The following theorem says that the i-th element of the array contains a value equal to ftarget i. Can you prove it?

theorem DPArray_get {ftarget:Nat->β}{n:Nat}
(arr:DPArray ftarget n)(i:Fin arr.val.size)
:ftarget i = arr.val[i].val.snd
:=sorry

Next, we want a function that takes a DPArray ftarget n and a solution to the n+1 subproblem, and produces a DPArray ftarget (n+1). We can obviously use Array.push to create the new array, but how do we prove the new array satisfies the requirements of the subtype DPArray ftarget (n+1)?

def DPArray_push {ftarget:Nat->β}{n:Nat}
(arr:DPArray ftarget n) (c:Cell ftarget)(h:c.val.fst=n+1)
:DPArray ftarget (n+1)
:=sorry
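
As a hint for the proof, the obligations boil down to a few facts about Array.push: the size grows by one, old indices keep their elements, and the new element sits at the old size. The first fact is sketched below on a generic array. The concrete arrays are made up for illustration, and the exact statements of the lemmas may differ between Lean versions.

```lean
-- Pushing grows the size by exactly one (`Array.size_push` is a simp lemma).
example (xs : Array Nat) (x : Nat) : (xs.push x).size = xs.size + 1 := by
  simp

-- Concretely, the pushed element lands at the old `size` index.
#eval (#[10, 20].push 30)       -- #[10, 20, 30]
#eval (#[10, 20].push 30)[2]!   -- 30
```

The other two facts correspond to the lemmas Array.getElem_push_lt and Array.getElem_push_eq, which appear in the final proof further down.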

Rod Cutting

Armed with these useful tools, let us tackle the rod cutting problem. We are given a rod of length n, and can cut up the rod into pieces to be sold. We are given a list of prices for different lengths of rod. Our objective is to cut the rod in a way that maximizes our revenue.

Here is a recursive implementation. One aspect that is more complicated than our earlier examples of Fibonacci and binomial coefficients is that here the number of recursive calls is not a constant, but rather depends on the size of the prices List. The following implementation makes the recursive calls inside a List.map.

def rodCutMap(prices:List (Nat×Nat))(n:Nat):Nat:=
  match hn:n with
  |0=>0
  |n'+1=>
    let pred := fun (p:Nat×Nat) => 0 < p.fst ∧ p.fst ≤ n'+1
    let candidates:List (Subtype pred) := prices.filterMap (fun p=> if h: pred p then some ⟨p, h⟩ else none)
    let values:=candidates.map (fun p=> p.val.snd+rodCutMap prices (n'+1-p.val.fst) )
    values.foldl (fun x y=>max x y) 0

On to our dynamic programming solution. Building on DPArray_push, it is natural to have a helper function that carries out one step of the dynamic programming.

def step (arr: DPArray (rodCutMap prices) n)
:DPArray (rodCutMap prices) (n+1)
:=sorry

The main task here is to construct a solution for subproblem n+1, and then pass it to DPArray_push. This is where we will be carrying out the same computation as the body of rodCutMap, except with array lookup replacing recursive calls.
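
To make the shape of that computation concrete, here is a minimal unverified sketch of the same bottom-up loop using a plain Array Nat with no proofs attached. The name rodCutTable is hypothetical and is not part of the solution in this post. Array.getD with default 0 stands in for the bounds-checked lookup that step has to justify.

```lean
-- Hypothetical, unverified sketch: a plain table of best revenues, no subtypes.
def rodCutTable (prices : List (Nat × Nat)) (n : Nat) : Array Nat :=
  (List.range n).foldl
    (fun (table : Array Nat) (k : Nat) =>
      let len := k + 1
      -- Same filter and max as in rodCutMap, with a lookup instead of a recursive call.
      let best := prices.foldl
        (fun (acc : Nat) (p : Nat × Nat) =>
          if 0 < p.fst ∧ p.fst ≤ len then
            max acc (p.snd + table.getD (len - p.fst) 0)
          else acc)
        0
      table.push best)
    #[0]  -- entry 0: a rod of length 0 earns nothing

#eval rodCutTable [(1, 1), (2, 5), (3, 8)] 4  -- #[0, 1, 5, 8, 10]
```

With these made-up prices, the best revenue for length 4 is 10, from two pieces of length 2. The verified version has to produce the same numbers while also carrying a proof that each entry equals rodCutMap prices at that index.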

Finally, taking everything together, we will be able to implement

def rodDP (prices:List (Nat×Nat) ) (n:Nat)
:{x:Nat // rodCutMap prices n=x}
:=sorry

Implementation and Proof

For this exercise, I asked Claude Sonnet 3.5 to implement each of the incomplete functions and theorems above, one by one and in that order. Included in the prompt are:

  • text files on Lean 4 syntax hints, and general description of style; same ones I prompted Sonnet with in Part 1.
  • the blog post for Part 1 containing the two examples of memoization.

This is more challenging because unlike in Part 1, here Sonnet is not given a worked example for the same type of problems.

Another change: unlike in Part 1, where I manually ran the Lean code and prompted Sonnet with any error messages, here I let Sonnet talk directly to Lean via LeanTool, my simple script allowing LLMs to run Lean code via tool calling. The chat is conducted in the command-line chat interface cli_chat.py in the LeanTool repo. During the chat, Sonnet is able to pass Lean code to the lean executable, receive error messages if there are any, and attempt to fix the errors. If errors remain after 5 attempts, it comes back to the chat, where I may provide further hints.

So how did Sonnet do this time? Overall it did very well: it carried out the bulk of the implementation and required only a relatively small number of hints from me. Here is the final code:

theorem DPArray_get {ftarget: Nat → β} {n: Nat}
    (arr: DPArray ftarget n) (i: Fin arr.val.size):
    ftarget i = arr.val[i].val.snd := by
  have h1 := arr.property.2 i  -- i = arr[i].val.fst
  have h2 := arr.val[i].property -- ftarget (arr[i].val.fst) = arr[i].val.snd
  have h3 : ftarget i = ftarget (arr.val[i].val.fst) := by
    congr
    
  rw [h3]
  exact h2

def DPArray_push {ftarget: Nat → β} {n: Nat}
    (arr: DPArray ftarget n) (c: Cell ftarget) (h: c.val.fst = n + 1):
    DPArray ftarget (n + 1) :=  
  let newArr := arr.val.push c  
  have hsize : newArr.size = (n + 1) + 1 := by
    rw [Array.size_push]
    rw [arr.property.1]


  have hind : ∀ i: Fin newArr.size, i = newArr[i].val.fst := by
    intro i
    let i' := i.val
    have hi' : i' < newArr.size := i.isLt
    have hsz : newArr.size = arr.val.size + 1 := Array.size_push arr.val c

    if h1 : i' < arr.val.size then
      -- For indices less than original size
      have heq : newArr[i] = arr.val[i'] := by
        apply Array.getElem_push_lt
        
      rw [heq]
      have := arr.property.2 ⟨i', h1⟩
      exact this
    else
      -- For the last index
      have : i' = arr.val.size := by
        apply Nat.eq_of_lt_succ_of_not_lt
        · rw [hsz] at hi'
          exact hi'
        · exact h1
      have heq : newArr[i] = c := by
        simp only [Fin.getElem_fin]
        simp[i'] at this
        simp [this]
        apply Array.getElem_push_eq
      rw [heq]
      rw [arr.property.1] at this

      have : i.val = n + 1 := this
      have : i.val = c.val.fst := by
        rw [h]
        exact this
      exact this

  ⟨newArr, And.intro hsize hind⟩


def step (prices: List (Nat×Nat)) {n: Nat} (arr: DPArray (rodCutMap prices) n):
    DPArray (rodCutMap prices) (n+1) :=
  let n' := n
  let pred := fun (p:Nat×Nat) => 0 < p.fst ∧ p.fst ≤ n'+1
  let candidates : List (Subtype pred) :=
    prices.filterMap (fun p=> if h: pred p then some ⟨p, h⟩ else none)

  let values := candidates.map (fun (p: Subtype pred) =>
    have h_bound : n'+1-(p.val.fst) < arr.val.size := by
      rw [arr.property.1]
      rcases p.property with ⟨h1, h2⟩
      have : p.val.fst ≤ n'+1 := h2
      have : 0 < p.val.fst := h1
      apply Nat.sub_lt_of_pos_le
      · exact this
      · exact h2
    let idx : Fin arr.val.size := ⟨n'+1-(p.val.fst), h_bound⟩
    p.val.snd + (arr.val[idx]).val.snd
  )

  let maxVal := values.foldl (fun x y=>max x y) 0


  let newCell : Cell (rodCutMap prices) :=
    ⟨(n+1, maxVal), by
      rw [rodCutMap]
      simp [n']
      -- For each candidate, array lookup gives same result as recursive call
      have h_lookup : ∀ (p: Subtype pred),
        let remaining := n'+1-(p.val.fst)
        let h_bound : remaining < arr.val.size := by {
          rw [arr.property.1]
          rcases p.property with ⟨h1, h2⟩
          apply Nat.sub_lt_of_pos_le
          · exact h1
          · exact h2
        }
        rodCutMap prices remaining = (arr.val[remaining]).val.snd := by
        intro p
        exact DPArray_get arr ⟨n'+1-(p.val.fst), by {
          rw [arr.property.1]
          rcases p.property with ⟨h1, h2⟩
          apply Nat.sub_lt_of_pos_le
          · exact h1
          · exact h2
        }⟩
      -- Therefore values lists are equal
      have h_values : values = candidates.map (fun p =>
        p.val.snd + rodCutMap prices (n'+1-p.val.fst)) := by
        simp only [values]
        congr
        funext p

        simp [h_lookup p]
      -- And therefore maxVal equals rodCutMap prices (n+1)

      simp[maxVal]
      rw [h_values]

    ⟩

  DPArray_push arr newCell (by simp)

def rodDP (prices:List (Nat×Nat)) (n:Nat):{x:Nat // rodCutMap prices n=x} :=
  match n with
  | 0 => ⟨0, by rw [rodCutMap]⟩
  | n'+1 =>
    -- Create initial array with just solution for 0
    let initCell : Cell (rodCutMap prices) := ⟨(0, 0), by rw [rodCutMap]⟩
    let initArr := Array.mk [initCell]
    have hsize : initArr.size = 0 + 1 := by simp [Array.size]
    have hind : ∀ i: Fin initArr.size, i = initArr[i].val.fst := by
      intro i
      have : i.val = 0 := by simp
      rw [this]
      simp [Array.get]
      rfl
    let arr0 : DPArray (rodCutMap prices) 0 := ⟨initArr, And.intro hsize hind⟩

    -- Build up array using step
    let rec buildArray (k: Nat) : DPArray (rodCutMap prices) k :=
      match k with
      | 0 => arr0
      | k'+1 => step prices (buildArray k')

    -- Get final array and extract solution
    let finalArr := buildArray n'
    let finalStep := step prices finalArr
    have h_idx : n'+1 < finalStep.val.size := by
      rw [finalStep.property.1] 
      simp
    let idx : Fin finalStep.val.size := ⟨n'+1, h_idx⟩
    ⟨finalStep.val[idx].val.snd, by
      have := DPArray_get finalStep idx
      have heq : idx.val = n'+1 := rfl
      rw [heq] at this
      exact this
    ⟩

Observations:

  • Sonnet is again extremely good at outputting (mostly) valid Lean 4 code, and at having the right “idea” for the proofs. The proofs are a bit more verbose than the solutions I came up with, but it is logically very clear what they are trying to do.
  • Sonnet was not able to fix all errors by calling lean and iterating. I provided further hints, including on the syntax of array indexing. Perhaps the most significant hint on proof structure was my suggestion to make the body of step closely mirror rodCutMap, and then to use the tactics congr and funext to help prove equality of the two computations.
  • This revealed a flaw in LeanTool: the lean executable does not (consistently) provide the goal state. It will print the goal state if the error was that a proof is missing for that goal, but it will not print goal states for sorrys. On the other hand, Sonnet often tends to start with a proof sketch with sorrys in some places, before going into the details. This is normally a very natural and nice style, but it does not get very informative feedback from the lean executable. Next, I plan to augment the output of LeanTool with goal states for sorrys. It looks like the load_sorry feature of the recent Pantograph project is exactly what I need!