Proving Dynamic Programming with Sonnet (Part 2)
This is Part 2 of a series in which I explore techniques for proving the correctness of algorithms in Lean, and then teach them to Claude Sonnet. You may want to start with Part 1, on memoization.
This time, let us do bottom-up dynamic programming. It has a lot in common with memoization, so we will be able to adapt most of our techniques from Part 1. The main implementation difference is that in memoization the control flow mostly mirrors the recursive implementation, whereas in bottom-up DP the control flow focuses on building up the data structure that holds the solutions.
Our goal is again to prove the equivalence of our dynamic programming solution with the recursive solution. We will take the well-known rod-cutting example, but will also aim to make our code reusable / adaptable to other problems.
In bottom-up dynamic programming, we again have a data structure storing solutions to subproblems. As in our memoization exercise, we attach a subtype to the pairs of values stored in the data structure.
import Mathlib.Tactic
def Cell (ftarget: α→β):=
{c: α × β//ftarget c.fst =c.snd}
For simplicity, let us focus on one-dimensional dynamic programming, with subproblems indexed by natural numbers. A natural data structure for this purpose is then an Array.
def DPArray (ftarget:Nat->β) (n:Nat)
:= {arr:Array (Cell ftarget)//
arr.size=n+1 ∧
∀ i: Fin arr.size,
i = arr[i].val.fst
}
An array of type DPArray ftarget n will store solutions to subproblems up to and including n.
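To make the Cell subtype concrete, here is a minimal sketch of building a cell by hand. The function timesTwo is hypothetical and chosen only for illustration; it assumes the Cell definition above is in scope.

```lean
-- Hypothetical target function, used only for this illustration.
def timesTwo (n : Nat) : Nat := 2 * n

-- A cell pairs an input with its output, plus a proof that they agree.
-- Here the proof obligation is `timesTwo 3 = 6`, which holds by computation.
example : Cell timesTwo := ⟨(3, 6), rfl⟩
```

A pair such as (3, 7) could not be packaged this way, because there is no proof of timesTwo 3 = 7. That is the point of the subtype: anything stored in the array is correct by construction.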
The following theorem says that the i-th element of the array contains a value equal to ftarget i. Can you prove it?
theorem DPArray_get {ftarget:Nat->β}{n:Nat}
(arr:DPArray ftarget n)(i:Fin arr.val.size)
:ftarget i = arr.val[i].val.snd
:=sorry
Next, we want a function that takes a DPArray ftarget n and a solution to subproblem n+1, and produces a DPArray ftarget (n+1). We can obviously use Array.push to create the new array, but how do we prove that the new array satisfies the requirements of the subtype DPArray ftarget (n+1)?
def DPArray_push {ftarget:Nat->β}{n:Nat}
(arr:DPArray ftarget n) (c:Cell ftarget)(h:c.val.fst=n+1)
:DPArray ftarget (n+1)
:=sorry
Rod Cutting
Armed with these useful tools, let us tackle the rod cutting problem.
We are given a rod of length n, and can cut up the rod into pieces to be sold. We are given a list of prices for different lengths of rod. Our objective is to cut the rod in a way that maximizes our revenue.
Here is a recursive implementation.
One aspect that is more complicated than in our earlier examples of Fibonacci numbers and binomial coefficients is that here the number of recursive calls is not a constant, but rather depends on the size of the prices list. The following implementation makes the recursive calls inside a List.map.
def rodCutMap(prices:List (Nat×Nat))(n:Nat):Nat:=
  match hn:n with
  |0=>0
  |n'+1=>
    let pred := fun (p:Nat×Nat) => 0 < p.fst ∧ p.fst ≤ n'+1
    let candidates:List (Subtype pred) := prices.filterMap (fun p=> if h: pred p then some ⟨ p, h⟩ else none)
    let values:=candidates.map (fun p=> p.val.snd+rodCutMap prices (n'+1-p.val.fst) )
    values.foldl (fun x y=>max x y) 0
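As a quick sanity check, here is a small worked example. The price list is hypothetical (each pair is read as (length, price), matching how rodCutMap uses p.fst and p.snd), and the sketch assumes rodCutMap compiles as written above.

```lean
-- Illustrative prices as (length, price) pairs; not taken from the post.
def samplePrices : List (Nat × Nat) := [(1, 1), (2, 5), (3, 8), (4, 9)]

-- Best revenues for lengths 1, 2, 3 are 1, 5, 8, so for length 4 the
-- candidates are 1 + 8, 5 + 5, 8 + 1 and 9 + 0.
#eval rodCutMap samplePrices 4  -- expected: 10 (two pieces of length 2)
```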
Onto our dynamic programming solution.
Building on DPArray_push, it is natural to have a helper function that carries out one step of the dynamic programming.
def step (arr: DPArray (rodCutMap prices) n)
:DPArray (rodCutMap prices) (n+1)
:=sorry
The main task here is to construct a solution for subproblem n+1, and then pass it to DPArray_push. This is where we will carry out the same computation as the body of rodCutMap, except with array lookups replacing recursive calls.
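Before adding proofs, it may help to see the computational content of such a step on its own. The following is a proof-free sketch, not the verified version developed in this post: the names rodStepPlain and rodPlain are hypothetical, and the table is a plain Array Nat rather than a DPArray.

```lean
-- Hypothetical, proof-free version of one DP step: `table[i]` is assumed to
-- hold the best revenue for a rod of length `i`, for every i < table.size.
def rodStepPlain (prices : List (Nat × Nat)) (table : Array Nat) : Array Nat :=
  let n := table.size  -- index of the entry about to be added
  let candidates := prices.filter (fun p => decide (0 < p.fst ∧ p.fst ≤ n))
  -- Array lookup replaces the recursive call. The default 0 is never used,
  -- because n - p.fst < n for every candidate.
  let values := candidates.map (fun p => p.snd + table.getD (n - p.fst) 0)
  table.push (values.foldl (fun x y => max x y) 0)

-- Start from the solution for length 0 and apply the step n times.
def rodPlain (prices : List (Nat × Nat)) (n : Nat) : Nat :=
  let table := (List.range n).foldl (fun t _ => rodStepPlain prices t) #[0]
  table.getD n 0

#eval rodPlain [(1, 1), (2, 5), (3, 8), (4, 9)] 4  -- expected: 10
```

Nothing here guarantees that the table entries are correct or that lookups stay in bounds. The DPArray subtype exists to carry exactly those guarantees alongside the data.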
Finally, putting everything together, we will be able to implement:
def rodDP (prices:List (Nat×Nat) ) (n:Nat)
:{x:Nat // rodCutMap prices n=x}
:=sorry
Implementation and Proof
For this exercise, I asked Claude Sonnet 3.5 to implement each of the incomplete functions and theorems above, one by one and in that order. Included in the prompt are:
- text files with Lean 4 syntax hints and a general description of style; the same ones I prompted Sonnet with in Part 1.
- the blog post for Part 1 containing the two examples of memoization.
This is more challenging because, unlike in Part 1, here Sonnet is not given a worked example for the same type of problem.
Another change: unlike in Part 1, where I manually ran the Lean code and prompted Sonnet with any error messages, here I let Sonnet talk directly to Lean via LeanTool, my simple script that allows LLMs to run Lean code via tool calling.
The chat is conducted in the command-line chat interface cli_chat.py in the LeanTool repo.
During the chat, Sonnet is able to pass Lean code to the lean executable, receive error messages if there are any, and attempt to fix the errors. If errors remain after 5 attempts, it comes back to the chat where I may provide further hints.
So how did Sonnet do this time? Overall it did very well: it was able to carry out the bulk of the implementation, requiring only a relatively small number of hints from me. Here is the final code:
theorem DPArray_get {ftarget: Nat → β} {n: Nat}
    (arr: DPArray ftarget n) (i: Fin arr.val.size):
    ftarget i = arr.val[i].val.snd := by
  have h1 := arr.property.2 i -- i = arr[i].val.fst
  have h2 := arr.val[i].property -- ftarget (arr[i].val.fst) = arr[i].val.snd
  have h3 : ftarget i = ftarget (arr.val[i].val.fst) := by
    congr
  rw [h3]
  exact h2
def DPArray_push {ftarget: Nat → β} {n: Nat}
    (arr: DPArray ftarget n) (c: Cell ftarget) (h: c.val.fst = n + 1):
    DPArray ftarget (n + 1) :=
  let newArr := arr.val.push c
  have hsize : newArr.size = (n + 1) + 1 := by
    rw [Array.size_push]
    rw [arr.property.1]
  have hind : ∀ i: Fin newArr.size, i = newArr[i].val.fst := by
    intro i
    let i' := i.val
    have hi' : i' < newArr.size := i.isLt
    have hsz : newArr.size = arr.val.size + 1 := Array.size_push arr.val c
    if h1 : i' < arr.val.size then
      -- For indices less than original size
      have heq : newArr[i] = arr.val[i'] := by
        apply Array.getElem_push_lt
      rw [heq]
      have := arr.property.2 ⟨i', h1⟩
      exact this
    else
      -- For the last index
      have : i' = arr.val.size := by
        apply Nat.eq_of_lt_succ_of_not_lt
        · rw [hsz] at hi'
          exact hi'
        · exact h1
      have heq : newArr[i] = c := by
        simp only [Fin.getElem_fin]
        simp[i'] at this
        simp [this]
        apply Array.getElem_push_eq
      rw [heq]
      rw [arr.property.1] at this
      have : i.val = n + 1 := this
      have : i.val = c.val.fst := by
        rw [h]
        exact this
      exact this
  ⟨newArr, And.intro hsize hind⟩
def step (prices: List (Nat×Nat)) {n: Nat} (arr: DPArray (rodCutMap prices) n):
    DPArray (rodCutMap prices) (n+1) :=
  let n' := n
  let pred := fun (p:Nat×Nat) => 0 < p.fst ∧ p.fst ≤ n'+1
  let candidates : List (Subtype pred) :=
    prices.filterMap (fun p=> if h: pred p then some ⟨p, h⟩ else none)
  let values := candidates.map (fun (p: Subtype pred) =>
    have h_bound : n'+1-(p.val.fst) < arr.val.size := by
      rw [arr.property.1]
      rcases p.property with ⟨h1, h2⟩
      have : p.val.fst ≤ n'+1 := h2
      have : 0 < p.val.fst := h1
      apply Nat.sub_lt_of_pos_le
      · exact this
      · exact h2
    let idx : Fin arr.val.size := ⟨n'+1-(p.val.fst), h_bound⟩
    p.val.snd + (arr.val[idx]).val.snd
  )
  let maxVal := values.foldl (fun x y=>max x y) 0
  let newCell : Cell (rodCutMap prices) :=
    ⟨(n+1, maxVal), by
      rw [rodCutMap]
      simp [n']
      -- For each candidate, array lookup gives same result as recursive call
      have h_lookup : ∀ (p: Subtype pred),
          let remaining := n'+1-(p.val.fst)
          let h_bound : remaining < arr.val.size := by {
            rw [arr.property.1]
            rcases p.property with ⟨h1, h2⟩
            apply Nat.sub_lt_of_pos_le
            · exact h1
            · exact h2
          }
          rodCutMap prices remaining = (arr.val[remaining]).val.snd := by
        intro p
        exact DPArray_get arr ⟨n'+1-(p.val.fst), by {
          rw [arr.property.1]
          rcases p.property with ⟨h1, h2⟩
          apply Nat.sub_lt_of_pos_le
          · exact h1
          · exact h2
        }⟩
      -- Therefore values lists are equal
      have h_values : values = candidates.map (fun p =>
          p.val.snd + rodCutMap prices (n'+1-p.val.fst)) := by
        simp only [values]
        congr
        funext p
        simp [h_lookup p]
      -- And therefore maxVal equals rodCutMap prices (n+1)
      simp[maxVal]
      rw [h_values]
    ⟩
  DPArray_push arr newCell (by simp)
def rodDP (prices:List (Nat×Nat)) (n:Nat):{x:Nat // rodCutMap prices n=x} :=
  match n with
  | 0 => ⟨0, by rw [rodCutMap]⟩
  | n'+1 =>
    -- Create initial array with just solution for 0
    let initCell : Cell (rodCutMap prices) := ⟨(0, 0), by rw [rodCutMap]⟩
    let initArr := Array.mk [initCell]
    have hsize : initArr.size = 0 + 1 := by simp [Array.size]
    have hind : ∀ i: Fin initArr.size, i = initArr[i].val.fst := by
      intro i
      have : i.val = 0 := by simp
      rw [this]
      simp [Array.get]
      rfl
    let arr0 : DPArray (rodCutMap prices) 0 := ⟨initArr, And.intro hsize hind⟩
    -- Build up array using step
    let rec buildArray (k: Nat) : DPArray (rodCutMap prices) k :=
      match k with
      | 0 => arr0
      | k'+1 => step prices (buildArray k')
    -- Get final array and extract solution
    let finalArr := buildArray n'
    let finalStep := step prices finalArr
    have h_idx : n'+1 < finalStep.val.size := by
      rw [finalStep.property.1]
      simp
    let idx : Fin finalStep.val.size := ⟨n'+1, h_idx⟩
    ⟨finalStep.val[idx].val.snd, by
      have := DPArray_get finalStep idx
      have heq : idx.val = n'+1 := rfl
      rw [heq] at this
      exact this
    ⟩
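With all the definitions in place, a short usage sketch shows the payoff. The price list is hypothetical, and the sketch assumes the code above compiles in your Lean and Mathlib version.

```lean
-- The computed value is the first component of the subtype.
#eval (rodDP [(1, 1), (2, 5), (3, 8), (4, 9)] 4).val  -- expected: 10

-- The equivalence with the recursive definition comes bundled with the result,
-- so no separate correctness theorem is needed.
example (prices : List (Nat × Nat)) (n : Nat) :
    rodCutMap prices n = (rodDP prices n).val :=
  (rodDP prices n).property
```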
Observations:
- Sonnet is again extremely good at outputting (mostly) valid Lean 4 code, and at having the right “idea” for the proofs. The proofs are a bit more verbose than the solutions I came up with, but it is logically very clear what each one is trying to do.
- Sonnet was not able to fix all errors by calling Lean and iterating. I provided further hints, including on the syntax of array indexing. Perhaps the most significant hint on proof structure was my suggestion to make the body of step closely mirror rodCutMap, and then to use the tactics congr and funext to help prove equality of the two computations.
- This revealed a flaw in LeanTool: the lean executable does not (consistently) provide the goal state. It will print the goal state if the error was that a proof is missing for that goal, but it will not print goal states for sorrys. On the other hand, Sonnet often tends to start with a proof sketch with sorrys in some places, before going into the details. This is normally a very natural and nice style, but it does not get very informative feedback from the lean executable. Next, I plan to augment the output of LeanTool with goal states for sorrys. It looks like the load_sorry feature of the recent Pantograph project is exactly what I need!
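To illustrate the congr and funext hint from the observations above, here is a minimal standalone sketch of the same pattern on a simpler goal. The functions f and g stand in for the array-lookup and recursive-call versions of the computation; the statement is illustrative and not part of the rod-cutting development.

```lean
example (xs : List Nat) (f g : Nat → Nat) (h : ∀ x, f x = g x) :
    xs.map f = xs.map g := by
  congr      -- both sides map over the same list, so it remains to show f = g
  funext x   -- function extensionality: show f x = g x for an arbitrary x
  exact h x
```

When two computations have the same shape, as step and rodCutMap do, this reduces the equality proof to a pointwise fact about the one place where they differ.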