Each nn.Linear(d_in, d_out) is replaced by two factors $A \in \mathbb{R}^{R \times d_{\text{in}}}$ and $B \in \mathbb{R}^{d_{\text{out}} \times R}$, where $R$ is the maximum rank. At inference, calling set_rank(r) simply slices the shared factors:
- $A_r = A[{:}r, :]$ (the first $r$ rows of $A$)
- $B_r = B[:, {:}r]$ (the first $r$ columns of $B$)
- $y = B_r (A_r x) + b$
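Slicing also makes the compute cost an explicit function of the active rank: the dense layer needs $d_{\text{in}} d_{\text{out}}$ multiply-adds per token, while the rank-$r$ path needs $r(d_{\text{in}} + d_{\text{out}})$. For a square $1024 \times 1024$ layer, for example, $r = 256$ cuts the layer's multiply-adds exactly in half.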
Because rank $r$ uses a prefix of the same factors as rank $r+1$, the image spaces of the effective weights $W_r = B_r A_r$ are nested:

$$\mathrm{Im}(W_1) \subseteq \mathrm{Im}(W_2) \subseteq \cdots \subseteq \mathrm{Im}(W_R)$$
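A quick way to see this concretely is to draw random factors and check that every column of $W_r$ already lies in the column space of $W_{r+1}$. The snippet below is just such a sanity check, not part of the layer:

```python
import torch

# Sanity check of the nesting property on random factors
# (float64 so the tolerance can be tight).
torch.manual_seed(0)
d_in, d_out, R = 16, 12, 8
A = torch.randn(R, d_in, dtype=torch.float64)
B = torch.randn(d_out, R, dtype=torch.float64)

for r in range(1, R):
    W_r = B[:, :r] @ A[:r, :]            # effective weight at rank r
    W_next = B[:, :r + 1] @ A[:r + 1, :]
    # Least-squares projection of W_r's columns onto the column space of
    # W_{r+1}; a near-zero residual means Im(W_r) ⊆ Im(W_{r+1}).
    sol = torch.linalg.lstsq(W_next, W_r).solution
    residual = (W_next @ sol - W_r).abs().max().item()
    assert residual < 1e-8, f"nesting violated at rank {r}"
```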
Training couples the ranks through an uncertainty-aware objective (a learned log-variance per rank) plus a curriculum that introduces the low ranks gradually. The result is a single set of weights that behaves as a continuum of models, anywhere from a 50% FLOPs reduction at a 5-percentage-point accuracy drop to interpolation at ranks the network never saw during training.
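The exact objective isn't reproduced here, but a minimal sketch of the idea, assuming a Kendall-style learned-log-variance weighting and a simple warmup curriculum (both are assumptions, and the sketch uses the DynamicLowRankLinear layer defined just below), could look like this:

```python
import random
import torch

def multi_rank_loss(model, x, y, ranks, log_sigma, criterion):
    # Evaluate the model at several active ranks and combine the per-rank
    # losses with learned log-variances (exp(-s) * L + s), so unreliable
    # ranks are down-weighted automatically. The exact weighting used in
    # training is an assumption here.
    total = 0.0
    for r in ranks:
        for m in model.modules():
            if isinstance(m, DynamicLowRankLinear):  # layer defined below
                m.set_rank(r)
        loss_r = criterion(model(x), y)
        total = total + torch.exp(-log_sigma[r - 1]) * loss_r + log_sigma[r - 1]
    return total

def curriculum_ranks(step, max_rank, warmup_steps=10_000):
    # Hypothetical curriculum: always train the full rank, and admit ever
    # lower random ranks as training progresses.
    frac = min(1.0, step / warmup_steps)
    lowest = max(1, int(round(max_rank * (1.0 - frac))))
    return [max_rank, random.randint(lowest, max_rank)]
```

Here log_sigma would be an nn.Parameter of length max_rank, optimized jointly with the model weights.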
import torch.nn as nn
import torch.nn.functional as F

class DynamicLowRankLinear(nn.Module):
    def __init__(self, in_features, out_features, max_rank, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.max_rank = max_rank
        self.active_rank = max_rank
        # Shared factors: A holds the (max_rank, in_features) weight,
        # B holds the (out_features, max_rank) weight and the bias.
        self.A = nn.Linear(in_features, max_rank, bias=False)
        self.B = nn.Linear(max_rank, out_features, bias=bias)

    def set_rank(self, rank):
        # Clamp to [1, max_rank]; the active rank selects a prefix of the factors.
        self.active_rank = max(1, min(rank, self.max_rank))

    def forward(self, x):
        r = self.active_rank
        A_r = self.A.weight[:r, :]   # (r, in_features)
        B_r = self.B.weight[:, :r]   # (out_features, r)
        h = F.linear(x, A_r)                    # project down to rank r
        return F.linear(h, B_r, self.B.bias)    # project back up, add bias
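A minimal usage sketch (the toy model and the set_rank_all helper are illustrative, not part of the layer above): build a model out of these layers, then dial the rank up or down per forward pass.

```python
import torch
import torch.nn as nn

def set_rank_all(model, r):
    # Illustrative helper: set the active rank of every dynamic layer.
    for m in model.modules():
        if isinstance(m, DynamicLowRankLinear):
            m.set_rank(r)

model = nn.Sequential(
    DynamicLowRankLinear(512, 512, max_rank=128),
    nn.ReLU(),
    DynamicLowRankLinear(512, 10, max_rank=128),
)

x = torch.randn(4, 512)
for r in (8, 32, 64, 128):       # including ranks never trained on
    set_rank_all(model, r)
    print(r, model(x).shape)     # same weights, different compute point
```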