Monotone, Bi-Lipschitz, and Polyak-Łojasiewicz Networks

Ruigang Wang    Krishnamurthy (Dj) Dvijotham    Ian R. Manchester
Abstract

This paper presents a new bi-Lipschitz invertible neural network, the BiLipNet, which has the ability to smoothly control both its Lipschitzness (output sensitivity to input perturbations) and inverse Lipschitzness (input distinguishability from different outputs). The second main contribution is a new scalar-output network, the PLNet, which is a composition of a BiLipNet and a quadratic potential. We show that PLNet satisfies the Polyak-Łojasiewicz condition and can be applied to learn non-convex surrogate losses with a unique and efficiently-computable global minimum. The central technical element in these networks is a novel invertible residual layer with certified strong monotonicity and Lipschitzness, which we compose with orthogonal layers to build the BiLipNet. The certification of these properties is based on incremental quadratic constraints, resulting in much tighter bounds than can be achieved with spectral normalization. Moreover, we formulate the calculation of the inverse of a BiLipNet – and hence the minimum of a PLNet – as a series of three-operator splitting problems, for which fast algorithms can be applied.

Machine Learning, ICML

1 Introduction

In many applications, it is desirable to learn neural networks with certified input-output behaviors, i.e., certain properties that are guaranteed by design. For example, Lipschitz-bounded networks have proven beneficial for stabilizing generative adversarial network (GAN) training (Arjovsky et al., 2017; Gulrajani et al., 2017), certifying robustness against adversarial attacks (Tsuzuku et al., 2018; Singla & Feizi, 2021; Zhang et al., 2021; Araujo et al., 2023; Wang & Manchester, 2023), and robust reinforcement learning (Russo & Proutiere, 2021; Barbara et al., 2024).

Model           inv. Lip. ($\downarrow$)   Lip. ($\uparrow$)   loss ($\downarrow$)
i-ResNet        0.80                       4.69                0.2090
i-DenseNet      0.82                       4.66                0.2091
BiLipNet        0.11                       9.97                0.0685
Best Possible   0.10                       10.0                0.0677

Figure 1: Fitting a step function, which is not Lipschitz, with certified $(0.1,10)$-Lipschitz models. Compared to the analytically-computed optimum, the proposed BiLipNet achieves much tighter bounds than models based on spectral normalization.

Another input-output property, invertibility, has received much attention in the deep learning literature since the introduction of normalizing flows (Dinh et al., 2015) for probability-density learning. Invertible neural networks have been applied to generative modeling (Dinh et al., 2017; Kingma & Dhariwal, 2018), probabilistic inference (Bauer & Mnih, 2019; Ward et al., 2019; Louizos & Welling, 2017), solving inverse problems (Ardizzone et al., 2018) and uncertainty estimation (Liu et al., 2020). A common way to construct invertible networks is to compose invertible affine transformations with more sophisticated invertible layers, including coupling flows (Dinh et al., 2017; Kingma & Dhariwal, 2018), auto-regressive models (Huang et al., 2018; De Cao et al., 2020; Ho et al., 2019), invertible residual layers (Chen et al., 2019; Behrmann et al., 2019), monotone networks (Ahn et al., 2022), and neural ordinary differential equations (Grathwohl et al., 2019); see also the surveys (Papamakarios et al., 2021; Kobyzev et al., 2020).

However, Behrmann et al. (2021) observed that commonly-used invertible networks suffer from exploding inverses and are thus prone to becoming numerically non-invertible. This observation motivates the input-output property of bi-Lipschitzness. A layer $\mathcal{F}:\mathbb{R}^n\rightarrow\mathbb{R}^n$ is said to be bi-Lipschitz with bound $(\mu,\nu)$, or simply $(\mu,\nu)$-Lipschitz, if the following inequalities hold for all $x,x'\in\mathbb{R}^n$:

$$\mu\|x-x'\|\leq\|\mathcal{F}(x)-\mathcal{F}(x')\|\leq\nu\|x-x'\|,$$

where $\|\cdot\|$ is the Euclidean norm. The bound $\nu$ controls the output sensitivity to input perturbations, while $\mu$ controls the input distinguishability from different outputs (Liu et al., 2020). We call $\mu$ the inverse Lipschitz bound of $\mathcal{F}$, since $\mathcal{F}^{-1}$ exists and is $1/\mu$-Lipschitz. The ratio $\tau:=\nu/\mu$ is called the distortion (Liang et al., 2023); it is an upper bound on the condition number of the Jacobian matrix of $\mathcal{F}$. A larger distortion implies more expressive flexibility in the model.
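
As a small numerical illustration of this definition (a sketch, not part of the paper's method), consider the hypothetical elementwise map $x\mapsto 2x+\tanh(x)$. Each of its slopes lies in $(2,3]$, so it is $(2,3)$-Lipschitz with distortion at most $1.5$. The helper name `empirical_bounds` and the sampling range are assumptions chosen for illustration only.

```python
import math
import random


def F(x):
    """Hypothetical elementwise map 2*x + tanh(x); each slope lies in (2, 3]."""
    return [2.0 * xi + math.tanh(xi) for xi in x]


def norm(v):
    """Euclidean norm of a list of floats."""
    return math.sqrt(sum(vi * vi for vi in v))


def empirical_bounds(f, dim=3, pairs=2000, seed=0):
    """Smallest and largest ratio ||f(x) - f(x')|| / ||x - x'|| over random pairs."""
    rng = random.Random(seed)
    lo, hi = float("inf"), 0.0
    for _ in range(pairs):
        x = [rng.uniform(-3.0, 3.0) for _ in range(dim)]
        xp = [rng.uniform(-3.0, 3.0) for _ in range(dim)]
        dx = norm([a - b for a, b in zip(x, xp)])
        if dx == 0.0:
            continue  # skip degenerate pairs
        dy = norm([a - b for a, b in zip(f(x), f(xp))])
        lo, hi = min(lo, dy / dx), max(hi, dy / dx)
    return lo, hi


# Sampled estimates of the inverse Lipschitz and Lipschitz bounds.
mu_hat, nu_hat = empirical_bounds(F)
```

Sampling can only over-estimate $\mu$ and under-estimate $\nu$; certified bounds require a structural argument of the kind developed in the rest of the paper.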

In this paper we argue that the bi-Lipschitz property is also useful for learning surrogate loss (or reward) functions. Given some input/output pairs of a loss function, the objective is to learn a function that matches the observed data and is “easy to optimize” in some sense. This problem appears in many areas, including Q-learning with continuous action spaces, see e.g. (Gu et al., 2016; Amos et al., 2017; Ryu et al., 2019), offline data-driven optimization (Grudzien et al., 2024), learning reward models in inverse reinforcement learning (Arora & Doshi, 2021), and data-driven surrogate losses for engineering process optimization (Cozad et al., 2014; Misener & Biegler, 2023). An important contribution was the input convex neural network (ICNN) (Amos et al., 2017). However, the requirement of input convexity can be too strong in many applications.

1.1 Contributions

  • We propose a novel strongly monotone and Lipschitz residual layer of the form $\mathcal{F}(x)=\mu x+\mathcal{H}(x)$. For the nonlinear block $\mathcal{H}$, we introduce a new architecture, the feed-through network (FTN), which takes a multi-layer perceptron (MLP) as its backbone and adds connections from each hidden layer to the input and output variables. For deep networks, this architecture can improve the model expressivity without suffering from vanishing gradients.

  • We parameterize FTNs with certified strong monotonicity (which implies inverse Lipschitzness) and Lipschitzness for $\mathcal{F}$ via the integral quadratic constraint (IQC) framework (Megretski & Rantzer, 1997) and the Cayley transform.

  • By composing strongly-monotone and Lipschitz FTN layers with orthogonal affine layers we obtain the BiLipNet, a new network architecture with smoothly-parameterized bi-Lipschitz bounds.

  • We formulate the model inversion $\mathcal{F}^{-1}$ as a three-operator splitting problem, which admits a numerically efficient solver (Davis & Yin, 2017).

  • We introduce a new scalar-output network $f:\mathbb{R}^n\rightarrow\mathbb{R}$, which we call a Polyak-Łojasiewicz network (or PLNet) since it satisfies the condition of the same name (Polyak, 1963; Lojasiewicz, 1963). It consists of a bi-Lipschitz network composed with a quadratic potential, and automatically satisfies favourable properties for surrogate loss learning, in particular existence of a unique global optimum which is efficiently computable.
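
To make the last contribution concrete, the following one-dimensional sketch rests on assumptions that are not taken from the paper: a hypothetical bi-Lipschitz map $g(x)=2x+\sin(x)$, whose slope lies in $[1,3]$, and the quadratic potential $f(x)=\tfrac12 g(x)^2$. Under these assumptions $f'=g\,g'$, so $\tfrac12 f'(x)^2\geq\mu^2\,(f(x)-\min f)$ with $\mu=1$, which is a Polyak-Łojasiewicz inequality, even though $f$ is not convex.

```python
import math

MU = 1.0  # inverse Lipschitz bound of the hypothetical map g below


def g(x):
    """Hypothetical scalar bi-Lipschitz map; its slope 2 + cos(x) lies in [1, 3]."""
    return 2.0 * x + math.sin(x)


def dg(x):
    """First derivative of g."""
    return 2.0 + math.cos(x)


def f(x):
    """Quadratic potential composed with g; the global minimum is f(0) = 0."""
    return 0.5 * g(x) ** 2


def df(x):
    """First derivative of f by the chain rule."""
    return g(x) * dg(x)


def d2f(x):
    """Second derivative of f: g'(x)^2 + g(x) * g''(x), with g''(x) = -sin(x)."""
    return dg(x) ** 2 - g(x) * math.sin(x)


grid = [-10.0 + 0.01 * i for i in range(2001)]
# PL inequality on the grid: 0.5 * f'(x)^2 >= MU^2 * (f(x) - f_min), with f_min = 0.
pl_holds = all(0.5 * df(x) ** 2 >= MU ** 2 * f(x) - 1e-9 for x in grid)
# Non-convexity: the second derivative is negative somewhere on the grid.
nonconvex = any(d2f(x) < 0.0 for x in grid)
```

In this toy case the unique minimizer is the point where $g$ vanishes, so minimizing $f$ amounts to inverting $g$ at zero.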

1.2 Related work

Bi-Lipschitz invertible layer.

In the literature, there are two types of invertible layers closely related to our models. The first is the invertible residual layer $\mathcal{F}(x)=x+\mathcal{H}(x)$ (Chen et al., 2019; Behrmann et al., 2019), where the nonlinear block $\mathcal{H}$ is a shallow network with Lipschitz bound $c<1$. In (Perugachi-Diaz et al., 2021), $\mathcal{H}$ is further extended to a deep MLP. It is easy to show that $\mathcal{F}$ is $(1-c)$-inverse Lipschitz and $(1+c)$-Lipschitz. In both cases, the Lipschitz regularization is via spectral normalization (Miyato et al., 2018), which we observe to be very conservative (see Figure 1). The second is a bi-Lipschitz layer defined by an implicit equation (Lu et al., 2021; Ahn et al., 2022). However, such layers require an iterative solver for both forward and inverse model inference. In contrast, our model has an explicit forward pass, and an iterative solution is only required for the inverse.
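
A minimal sketch of the first type of layer, assuming the hypothetical scalar block $\mathcal{H}(x)=c\tanh(x)$ with $c=0.8$, which is $c$-Lipschitz. Because $c<1$, the map $x\mapsto y-\mathcal{H}(x)$ is a contraction, so the standard Banach fixed-point iteration recovers the inverse. The function names are illustrative and not taken from any of the cited works.

```python
import math

C = 0.8  # Lipschitz bound of the hypothetical nonlinear block; must be < 1


def H(x):
    """Hypothetical nonlinear block with Lipschitz bound C."""
    return C * math.tanh(x)


def forward(x):
    """Residual layer F(x) = x + H(x), which is (1 - C, 1 + C)-Lipschitz."""
    return x + H(x)


def inverse(y, iters=200):
    """Solve x + H(x) = y with the contraction x <- y - H(x)."""
    x = y
    for _ in range(iters):
        x = y - H(x)
    return x


x_true = 1.3
x_rec = inverse(forward(x_true))          # should recover x_true
slope = forward(2.0) - forward(1.0)       # chord slope, lies in [1 - C, 1 + C]
```

The error shrinks by at least a factor $c$ per iteration, so this simple inversion scheme slows down as $c\to 1$, i.e., as the inverse Lipschitz bound $1-c$ approaches zero.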

IQC-based Lipschitz estimation and training.

In (Fazlyab et al., 2019), the IQC framework of (Megretski & Rantzer, 1997) was first applied to obtain accurate Lipschitz bound estimates for deep networks with slope-restricted activations. It was later pointed out by (Wang et al., 2022) that the IQC condition for Lipschitzness (Fazlyab et al., 2019) is Shor’s relaxation of a “Rayleigh quotient” quadratically constrained quadratic program (QCQP). Direct (i.e. unconstrained) parameterizations based on IQC were proposed in (Revay et al., 2020) for deep equilibrium networks, in (Araujo et al., 2023) for residual networks, in (Wang & Manchester, 2023) for deep MLPs and CNNs, and in (Revay et al., 2023) for recurrent models. It was pointed out by (Havens et al., 2023) that many recent Lipschitz model parameterizations (Meunier et al., 2022; Prach & Lampert, 2022; Araujo et al., 2023; Wang & Manchester, 2023) are special cases of (Revay et al., 2020). In a recent work (Pauli et al., 2024), IQC-based Lipschitz estimation was extended to more general activations such as GroupSort and MaxMin. All of these are for one-sided (upper) Lipschitzness, whereas our work applies the IQC framework to monotonicity and bi-Lipschitzness.

Bi-Lipschitz networks for learning-based surrogate optimization.

Liang et al. (2023) use bi-Lipschitz networks to learn a surrogate constraint set, while our work focuses on surrogate loss learning. Both works use the distortion bound as an important regularizer. The difference is that the distortion estimation in (Liang et al., 2023) is based on data samples, while our work offers certified and smoothly-parameterized distortion bounds.

2 Preliminaries

We give some definitions for a mapping $\mathcal{F}:\mathbb{R}^n\rightarrow\mathbb{R}^n$.

Definition 2.1.

$\mathcal{F}$ is said to be $\mu$-strongly monotone with $\mu>0$ if for all $x,x'\in\mathbb{R}^n$ we have

$$\left\langle\mathcal{F}(x)-\mathcal{F}(x'),\,x-x'\right\rangle\geq\mu\|x-x'\|^2,$$

where $\left\langle\cdot,\cdot\right\rangle$ is the Euclidean inner product: $\left\langle a,b\right\rangle=a^\top b$. $\mathcal{F}$ is monotone if the above condition holds for $\mu=0$.

Definition 2.2.

$\mathcal{F}$ is said to be $\nu$-Lipschitz with $\nu>0$ if

$$\|\mathcal{F}(x)-\mathcal{F}(x')\|\leq\nu\|x-x'\|,\quad\forall x,x'\in\mathbb{R}^n.$$

$\mathcal{F}$ is said to be $\mu$-inverse Lipschitz with $\mu>0$ if

$$\|\mathcal{F}(x)-\mathcal{F}(x')\|\geq\mu\|x-x'\|,\quad\forall x,x'\in\mathbb{R}^n.$$

$\mathcal{F}$ is said to be bi-Lipschitz with $\nu\geq\mu>0$, or simply $(\mu,\nu)$-Lipschitz, if it is $\mu$-inverse Lipschitz and $\nu$-Lipschitz.

For any $(\mu,\nu)$-Lipschitz mapping $\mathcal{F}$, its inverse $\mathcal{F}^{-1}$ is well-defined and $(1/\nu,1/\mu)$-Lipschitz (Yeh, 2006). By the Cauchy–Schwarz inequality, strong monotonicity implies inverse Lipschitzness, see Figure 2. A notable difference between monotonicity and bi-Lipschitzness is their composition behaviour. Given two bi-Lipschitz mappings $\mathcal{F}_1,\mathcal{F}_2$, their composition $\mathcal{F}=\mathcal{F}_2\circ\mathcal{F}_1$ is also bi-Lipschitz with bound $(\mu_1\mu_2,\nu_1\nu_2)$, where $(\mu_1,\nu_1)$ and $(\mu_2,\nu_2)$ are the bi-Lipschitz bounds of $\mathcal{F}_1$ and $\mathcal{F}_2$, respectively.
In contrast, given two strongly monotone mappings $\mathcal{F}_1,\mathcal{F}_2$ with monotonicity bounds $\mu_1,\mu_2$, the composition $\mathcal{F}=\mathcal{F}_2\circ\mathcal{F}_1$ need not be strongly monotone, although it is still $\mu_1\mu_2$-inverse Lipschitz. To quantify the flexibility of bi-Lipschitz maps, we introduce the following:
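
The Cauchy–Schwarz step is $\mu\|\Delta x\|^2\leq\langle\Delta y,\Delta x\rangle\leq\|\Delta y\|\,\|\Delta x\|$, hence $\|\Delta y\|\geq\mu\|\Delta x\|$. The composition behaviour can be illustrated with a small example that is not from the paper: a planar rotation $R_\theta$ satisfies $\langle R_\theta x,x\rangle=\cos\theta\,\|x\|^2$, so a rotation by 60 degrees is $0.5$-strongly monotone. Composing two of them gives a rotation by 120 degrees, which is not monotone but still preserves norms, so the inverse Lipschitz bound $\mu_1\mu_2=0.25$ holds.

```python
import math


def rotate(x, degrees):
    """Rotate a 2-D point counter-clockwise by the given angle."""
    t = math.radians(degrees)
    return (math.cos(t) * x[0] - math.sin(t) * x[1],
            math.sin(t) * x[0] + math.cos(t) * x[1])


def inner(a, b):
    """Euclidean inner product in the plane."""
    return a[0] * b[0] + a[1] * b[1]


# Rotations are linear, so F(x) - F(x') = F(x - x') and one difference vector suffices.
dx = (1.0, 2.0)
dy_single = rotate(dx, 60)            # F1 applied to dx
dy_composed = rotate(dy_single, 60)   # F2(F1(.)) is a rotation by 120 degrees

mono_single = inner(dy_single, dx) / inner(dx, dx)        # cos(60 deg)
mono_composed = inner(dy_composed, dx) / inner(dx, dx)    # cos(120 deg), negative
gain_composed = math.sqrt(inner(dy_composed, dy_composed) / inner(dx, dx))
```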

Definition 2.3.

$\mathcal{F}$ satisfies a distortion bound $\tau$ with $\tau\geq 1$ if $\mathcal{F}$ is $(\mu,\nu)$-Lipschitz with $\tau=\nu/\mu$.

For an invertible affine mapping $\mathcal{F}(x)=Px+q$, the condition number of $P$ is a distortion bound. An orthogonal mapping (i.e., $P^\top P=I$) has the smallest possible model distortion $\tau=1$. Distortion bounds satisfy a composition property, i.e., if $\mathcal{F}_1,\mathcal{F}_2$ have distortion bounds of $\tau_1,\tau_2$, then $\mathcal{F}_2\circ\mathcal{F}_1$ satisfies a distortion bound of $\tau_1\tau_2$. Both $\mathcal{F}$ and $\mathcal{F}^{-1}$ have the same distortion.
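
These properties can be checked numerically for affine maps. The sketch below assumes NumPy is available and uses arbitrary example matrices `P1`, `P2` and a rotation `Q`; none of these come from the paper.

```python
import numpy as np


def distortion(P):
    """Condition number of P: ratio of its largest to smallest singular value."""
    s = np.linalg.svd(P, compute_uv=False)
    return float(s[0] / s[-1])


# Arbitrary invertible example matrices (illustrative values only).
P1 = np.array([[2.0, 1.0], [0.0, 1.0]])
P2 = np.array([[1.0, 0.0], [3.0, 4.0]])
t = np.pi / 6
Q = np.array([[np.cos(t), -np.sin(t)], [np.sin(t), np.cos(t)]])  # orthogonal

tau1, tau2 = distortion(P1), distortion(P2)
tau_comp = distortion(P2 @ P1)             # composition of the two linear maps
tau_orth = distortion(Q)                   # orthogonal map
tau_inv = distortion(np.linalg.inv(P1))    # inverse map
```

Here `tau_comp` never exceeds `tau1 * tau2`, `tau_orth` equals one up to rounding, and `tau_inv` matches `tau1`, in line with the three statements above.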

Figure 2: This figure depicts the possible ranges of $\Delta y=\mathcal{F}(x')-\mathcal{F}(x)$ on $\mathbb{R}^2$ for a given $\Delta x=x'-x$. The ring (blue area) is for $(\mu,\nu)$-Lipschitz $\mathcal{F}$ while the half moon (red area) is for a $\mu$-strongly monotone and $\nu$-Lipschitz $\mathcal{F}$. The largest angle between $\Delta x$ and $\Delta y$ satisfies $\cos\alpha=\tau^{-1}$ with $\tau=\nu/\mu$ as the distortion.

Surrogate loss learning.

Let $\mathcal{D}$ be a dataset containing finitely many samples $x_i\in\mathbb{R}^n$ and $y_i=\mathfrak{f}(x_i)\in\mathbb{R}$, where $\mathfrak{f}$ is an unknown loss function. The task is to learn a surrogate loss $\hat{f}$ from $\mathcal{D}$, i.e., $\hat{f}=\operatorname*{arg\,min}_{f\in\mathfrak{F}}\,\mathbb{E}_{(x,y)\sim\mathcal{D}}\left[(f(x)-y)^2\right]$, where $\mathfrak{F}$ is the model set (e.g. neural networks). In many applications, it is highly desirable that each $f\in\mathfrak{F}$ has a unique and efficiently-computable global minimum. An important model class is the input convex neural network (ICNN) (Amos et al., 2017). Since every $f\in\mathfrak{F}$ is then convex w.r.t. $x$, any local minimum is a global minimum. Moreover, there exists a rich literature on convex optimization. Although convexity is favourable for downstream optimization problems, it might be a very stringent requirement for fitting the dataset $\mathcal{D}$.
In this work we aim to construct a model set $\mathfrak{F}$ such that every $f\in\mathfrak{F}$ need not be convex but still possesses those favourable properties for optimization. In Section 5, we will show that the construction of such a model set relies on bi-Lipschitz neural networks.

3 Monotone and bi-Lipschitz Networks

In this section we first present the construction of $\mu$-strongly monotone and $\nu$-Lipschitz residual layers of the form $\mathcal{F}(x)=\mu x+\mathcal{H}(x)$. We then construct bi-Lipschitz networks by deep composition of the new monotone and Lipschitz layers with orthogonal linear layers.

3.1 Feed-through network

For the nonlinear block $\mathcal{H}$, we introduce a network architecture, called feed-through network (FTN), which takes an MLP as its backbone and then connects each hidden layer to input and output variables, see Figure 3. To be specific, the residual layer $\mathcal{F}(x)=\mu x+\mathcal{H}(x)$ can be written as

$$\begin{split}z_k&=\sigma(W_k z_{k-1}+U_k x+b_k),\quad z_0=0\\ y&=\mu x+\sum_{k=1}^{L}Y_k z_k+b_y\end{split}\tag{1}$$

where $z_k\in\mathbb{R}^{m_k}$ are the hidden variables, and $U_k,W_k,Y_k$ and $b_k,b_y$ are the learnable weights and biases, respectively. Throughout the paper we assume that the activation $\sigma$ is a scalar nonlinearity with slope restricted in $[0,1]$, which is satisfied (possibly with rescaling) by common activation functions such as ReLU, tanh, and sigmoid.
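
The forward pass in (1) can be sketched as follows, assuming NumPy and tanh as the activation. The function name `ftn_layer`, the list-based weight layout, and the random weights are illustrative assumptions; random weights carry no monotonicity or Lipschitz certificate.

```python
import numpy as np


def ftn_layer(x, mu, Ws, Us, Ys, bs, by, act=np.tanh):
    """Evaluate y = mu*x + sum_k Y_k z_k + b_y with z_k = act(W_k z_{k-1} + U_k x + b_k).

    Ws[0] is ignored because z_0 = 0, so the first hidden layer only sees x.
    """
    y = mu * x + by
    z = None
    for k in range(len(Us)):
        pre = Us[k] @ x + bs[k]          # feed-through from the input
        if k > 0:
            pre = pre + Ws[k] @ z        # MLP backbone connection
        z = act(pre)
        y = y + Ys[k] @ z                # feed-through to the output
    return y


# Illustrative sizes and weights (hypothetical, uncertified).
rng = np.random.default_rng(0)
n, widths = 2, [4, 3]
Us = [rng.standard_normal((m_k, n)) for m_k in widths]
Ws = [None, rng.standard_normal((widths[1], widths[0]))]
Ys = [rng.standard_normal((n, m_k)) for m_k in widths]
bs = [np.zeros(m_k) for m_k in widths]
x = np.array([0.3, -1.2])
y = ftn_layer(x, 0.5, Ws, Us, Ys, bs, by=np.zeros(n))
```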

Remark 3.1.

FTN contains both short paths $x\rightarrow z_i\rightarrow y$ preventing vanishing gradients and long paths $x\rightarrow z_i\rightarrow\cdots\rightarrow z_j\rightarrow y$ improving model expressivity (see Figure 3).

Figure 3: The proposed invertible residual network $\mathcal{F}(x)=\mu x+\mathcal{H}(x)$ where the nonlinear block $\mathcal{H}$ is a feed-through network, whose hidden layers are directly connected to the input and output.

3.2 SDP conditions for monotonicity and Lipschitzness

The first step towards our parameterization is to establish strong monotonicity and Lipschitzness for $\mathcal{F}$ via semidefinite programming (SDP) conditions. For this, we rewrite $\mathcal{F}$ in a compact form:

$$z=\sigma(Wz+Ux+b),\quad y=\mu x+Yz+b_y\tag{2}$$

where $z=\bigl[\,z_1^\top\;\cdots\;z_L^\top\,\bigr]^\top$, $b=\bigl[\,b_1^\top\;\cdots\;b_L^\top\,\bigr]^\top$, and

$$W=\begin{bmatrix}0&&&\\ W_2&0&&\\ &\ddots&\ddots&\\ &&W_L&0\end{bmatrix},\quad U=\begin{bmatrix}U_1\\ U_2\\ \vdots\\ U_L\end{bmatrix},\quad Y=\begin{bmatrix}Y_1&Y_2&\cdots&Y_L\end{bmatrix}.$$

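
As a consistency check between (1) and (2), the sketch below (assuming NumPy, tanh activations, and arbitrary random weights, none of which come from the paper) evaluates the layer-wise recursion, assembles the stacked matrices, and confirms that the stacked hidden vector satisfies the compact equation.

```python
import numpy as np

rng = np.random.default_rng(0)
n, widths = 3, [4, 5, 2]                 # input size and hidden widths m_1..m_L (arbitrary)
L, m = len(widths), sum(widths)
offs = [0]
for mk in widths:
    offs.append(offs[-1] + mk)           # block offsets inside the stacked vector z

# Layer-wise weights; W_1 is absent because z_0 = 0.
Us = [rng.standard_normal((mk, n)) for mk in widths]
Ws = [None] + [rng.standard_normal((widths[k], widths[k - 1])) for k in range(1, L)]
Ys = [rng.standard_normal((n, mk)) for mk in widths]
bs = [rng.standard_normal(mk) for mk in widths]
by, mu = rng.standard_normal(n), 0.5
x = rng.standard_normal(n)

# Layer-wise evaluation of (1).
zs = []
for k in range(L):
    pre = Us[k] @ x + bs[k]
    if k > 0:
        pre = pre + Ws[k] @ zs[-1]
    zs.append(np.tanh(pre))
y_layerwise = mu * x + by
for Yk, zk in zip(Ys, zs):
    y_layerwise = y_layerwise + Yk @ zk

# Stacked matrices of the compact form (2): W is strictly block lower triangular.
W = np.zeros((m, m))
for k in range(1, L):
    W[offs[k]:offs[k + 1], offs[k - 1]:offs[k]] = Ws[k]
U, Y = np.vstack(Us), np.hstack(Ys)
b, z = np.concatenate(bs), np.concatenate(zs)

residual = float(np.max(np.abs(z - np.tanh(W @ z + U @ x + b))))
y_compact = mu * x + Y @ z + by
```

Because $W$ is strictly block lower triangular, the implicit-looking equation in (2) is solved exactly by a single forward sweep over the layers.
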
Theorem 3.2.

$\mathcal{F}$ is $\mu$-strongly monotone and $\nu$-Lipschitz if there exists a $\Lambda\in\mathbb{D}_+^m$, where $\mathbb{D}_+^m$ is the set of positive diagonal matrices, such that the following conditions hold:

$$Y=U^\top\Lambda,\quad 2\Lambda-\Lambda W-W^\top\Lambda\succeq\frac{2}{\gamma}Y^\top Y\tag{3}$$

where $\gamma=\nu-\mu>0$.
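
A numerical sanity check of Theorem 3.2 in a deliberately simple special case, assuming NumPy: one hidden layer (so $W=0$), $\Lambda=I$, and tanh activation. In this case (3) reduces to $\|U\|_2^2\leq\gamma$. The helper name `condition_3_holds` and all numerical values are illustrative assumptions.

```python
import numpy as np


def condition_3_holds(U, W, lam, mu, nu, tol=1e-9):
    """Check (3) for Lambda = diag(lam) > 0, with Y set to U^T Lambda."""
    gamma = nu - mu
    Lam = np.diag(lam)
    Y = U.T @ Lam                                      # equality part of (3)
    M = 2 * Lam - Lam @ W - W.T @ Lam - (2.0 / gamma) * (Y.T @ Y)
    return bool(np.linalg.eigvalsh(M).min() >= -tol), Y


rng = np.random.default_rng(0)
n, m, mu, nu = 3, 6, 0.5, 2.0
U = rng.standard_normal((m, n))
U *= np.sqrt(0.9 * (nu - mu)) / np.linalg.norm(U, 2)   # now ||U||_2^2 = 0.9 * gamma
W = np.zeros((m, m))                                    # single hidden layer
b = rng.standard_normal(m)
ok, Y = condition_3_holds(U, W, np.ones(m), mu, nu)


def F(x):
    """Compact form (2) with W = 0 and b_y = 0."""
    return mu * x + Y @ np.tanh(U @ x + b)


# Empirical monotonicity and Lipschitz ratios over random pairs.
mono, gain = [], []
for _ in range(1000):
    x, xp = rng.standard_normal(n), rng.standard_normal(n)
    dx, dy = x - xp, F(x) - F(xp)
    mono.append(float(dy @ dx) / float(dx @ dx))
    gain.append(float(np.linalg.norm(dy) / np.linalg.norm(dx)))
```

All sampled monotonicity ratios stay above $\mu$ and all gains stay below $\nu$, as the theorem predicts for this instance. Such sampling only illustrates the certificate and does not replace it.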

Remark 3.3.

The above conditions are obtained by applying the IQC theory (Megretski & Rantzer, 1997) to (2).

3.3 Model parameterization

Let $\Theta$ be the set of all $\theta=\{U,W,Y,\Lambda\}$ such that Condition (3) holds. Since it is generally not scalable to train a model with SDP constraints, we instead construct a direct parameterization, i.e., one that is both unconstrained and complete:

Definition 3.4.

A direct parameterization of a constraint set $\Theta$ is a surjective differentiable mapping $\mathcal{M}:\mathbb{R}^N\to\Theta$, i.e., for any $\phi\in\mathbb{R}^N$ we have $\mathcal{M}(\phi)\in\Theta$, and the image of $\mathbb{R}^N$ under $\mathcal{M}$ is all of $\Theta$, i.e., $\mathcal{M}(\mathbb{R}^N)=\Theta$.

A direct parameterization allows us to replace a constrained optimization over $\theta\in\Theta$ with an unconstrained optimization over $\phi\in\mathbb{R}^N$ without loss of generality. This enables use of standard first-order optimization algorithms such as SGD or ADAM (Kingma & Ba, 2015).
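
A toy instance of this idea, not taken from the paper: the constraint set $\Theta=\{\lambda>0\}$ is directly parameterized by $\lambda=e^{d}$ with $d\in\mathbb{R}$, which is smooth and onto. Plain gradient descent on $d$ then solves a constrained problem without any projection step. The objective $\lambda-\log\lambda$ and the step size are assumptions chosen for illustration.

```python
import math


def param(d):
    """Direct parameterization of {lam > 0}: differentiable and surjective."""
    return math.exp(d)


def loss(lam):
    """Hypothetical objective, only defined for lam > 0; minimized at lam = 1."""
    return lam - math.log(lam)


def grad_d(d):
    """Chain rule: d(loss)/d(d) = (1 - 1/lam) * lam = lam - 1, with lam = exp(d)."""
    return param(d) - 1.0


d = 2.0                       # unconstrained free parameter
for _ in range(500):
    d -= 0.1 * grad_d(d)      # ordinary gradient descent, no constraint handling
lam_opt = param(d)            # always positive by construction
```

The same exponential map appears in the construction that follows, where positive diagonal matrices are obtained as $\mathrm{diag}(e^{d_k})$.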

We now construct a direct parameterization for FTNs satisfying (3). Here we present the main ideas; see Appendix A for full details. First, we introduce the free parameters

$$\phi=\{F^p,F^q\}\cup\{d_k,F_k^a,F_k^b\}_{1\leq k\leq L}$$

where $F^p\in\mathbb{R}^{n\times n}$, $F^q\in\mathbb{R}^{m\times n}$, $d_k\in\mathbb{R}^{m_k}$, $F_k^a\in\mathbb{R}^{m_k\times m_k}$ and $F_k^b\in\mathbb{R}^{m_{k-1}\times m_k}$ with $m_0=0$. Then, we compute some intermediate variables $\Psi_k=\mathrm{diag}\bigl(e^{d_k}\bigr)$ and

$$\begin{bmatrix}A_k^\top\\ B_k^\top\end{bmatrix}=\operatorname{Cayley}\left(\begin{bmatrix}F_k^a\\ F_k^b\end{bmatrix}\right),\quad\begin{bmatrix}P\\ Q\end{bmatrix}=\operatorname{Cayley}\left(\begin{bmatrix}F^p\\ F^q\end{bmatrix}\right)$$

where $\operatorname{Cayley}:\mathbb{R}^{n\times p}\rightarrow\mathbb{R}^{n\times p}$ with $n\geq p$ is defined by

$$J=\operatorname{Cayley}\left(\begin{bmatrix}G\\ H\end{bmatrix}\right):=\begin{bmatrix}(I+Z)^{-1}(I-Z)\\ -2H(I+Z)^{-1}\end{bmatrix}\tag{4}$$

with $Z = G^\top - G + H^\top H$. It can be verified that $J^\top J = I$ for any $G \in \mathbb{R}^{p\times p}$ and $H \in \mathbb{R}^{(n-p)\times p}$. Note that $P$ will not be used for further weight construction as its purpose is to ensure that $Q^\top Q \preceq I$. Next we set

$$V_k = 2B_k A_{k-1}^\top, \quad S_k = A_k Q_k - B_k Q_{k-1}$$

where $Q = \bigl[\,Q_1^\top \; \cdots \; Q_L^\top\,\bigr]^\top$ and $B_1 = 0$, $Q_0 = 0$. Finally, we construct the weights in (1) as:

$$\begin{split}U_k &= \sqrt{2\gamma}\,\Psi_k^{-1} S_k, \; W_k = \Psi_k^{-1} V_k \Psi_{k-1}, \\ Y_k &= \sqrt{\frac{\gamma}{2}}\, S_k^\top \Psi_k, \; \Lambda_k = \frac{1}{2}\Psi_k^2.\end{split} \tag{5}$$
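The Cayley map (4) and its orthogonality property can be checked numerically. The following is a minimal NumPy sketch, not the authors' implementation; the function name `cayley`, the sizes `p` and `n`, and the random seed are assumptions made for illustration.

```python
import numpy as np

def cayley(G, H):
    """Cayley map (4): G is (p x p), H is ((n - p) x p); returns J of shape (n x p)."""
    p = G.shape[0]
    I = np.eye(p)
    Z = G.T - G + H.T @ H                 # skew-symmetric part plus a PSD part
    M = np.linalg.inv(I + Z)              # I + Z is invertible: its symmetric part is I + H^T H
    return np.vstack([M @ (I - Z), -2.0 * H @ M])

rng = np.random.default_rng(0)            # arbitrary seed, illustration only
p, n = 3, 5                               # hypothetical sizes with n >= p
G = rng.standard_normal((p, p))
H = rng.standard_normal((n - p, p))

J = cayley(G, H)
orth_error = np.linalg.norm(J.T @ J - np.eye(p))   # should be at round-off level

# The lower block plays the role of Q: since P^T P + Q^T Q = I, Q^T Q <= I follows.
Q = J[p:]
max_eig_QtQ = np.linalg.eigvalsh(Q.T @ Q).max()
```

Here `orth_error` measures how far $J^\top J$ is from the identity, and `max_eig_QtQ` illustrates why the discarded block $P$ is still needed: it is what forces $Q^\top Q \preceq I$.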
Proposition 3.5.

The model parameterization $\mathcal{M}$ defined in (5) is a direct parameterization for the set $\Theta$, i.e. all models (1) satisfying Condition (3).

This means that we can learn the free parameter $\phi$ using first-order methods without any loss of model expressivity.

The construction is now done, but we note that $\Psi_k$ is shared between layers $k$ and $k+1$. To have a modular implementation, we introduce new variables $\hat{z} = \Psi z$ and bias $\hat{b} = \Psi b$ with $\Psi = \operatorname{diag}(\Psi_1, \Psi_2, \ldots, \Psi_L)$. Then, (2) can be rewritten as follows (see Appendix A)

$$\hat{z} = \hat{\sigma}\bigl(V\hat{z} + \sqrt{2\gamma}\,Sx + \hat{b}\bigr), \; y = \mu x + \sqrt{\gamma/2}\,S^\top \hat{z} + b_y \tag{6}$$

where $\hat{\sigma}(x) := \Psi\sigma\left(\Psi^{-1}x\right)$ is a $(0,1)$-Lipschitz layer with learnable scaling $\Psi$, and the weights $S, V$ can be written as

$$S = \begin{bmatrix}S_1 \\ S_2 \\ \vdots \\ S_L\end{bmatrix}, \quad V = \begin{bmatrix}0 & & & \\ V_2 & 0 & & \\ & \ddots & \ddots & \\ & & V_L & 0\end{bmatrix}. \tag{7}$$

Number of free parameters.

Consider an $L$-layer FTN (1) where each layer has the same width, i.e. $m_k = d$. The bi-Lipschitz network based on spectral normalization (Liu et al., 2020) has $2Ld^2$ free parameters while our model size is $(3L+1)d^2 + Ld$. Since the $Ld^2$ term dominates for deep and wide networks, our model has roughly 1.5 times as many parameters as the model from (Liu et al., 2020).
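As a quick arithmetic check of the two counts quoted above, the sketch below evaluates both formulas for a couple of hypothetical sizes. The function names and the chosen values of `L` and `d` are illustrative assumptions.

```python
def free_params_spectral(L, d):
    """Free-parameter count quoted for the spectral-normalization model: 2 L d^2."""
    return 2 * L * d * d

def free_params_ours(L, d):
    """Free-parameter count quoted for the proposed layer: (3L + 1) d^2 + L d."""
    return (3 * L + 1) * d * d + L * d

def ratio(L, d):
    return free_params_ours(L, d) / free_params_spectral(L, d)

# Hypothetical sizes; the ratio equals 1.5 + 1/(2L) + 1/(2d) and tends to 1.5.
ratio_small = ratio(4, 64)
ratio_large = ratio(32, 512)
```

Dividing the two formulas gives $1.5 + \tfrac{1}{2L} + \tfrac{1}{2d}$, which makes the "roughly 1.5 times" statement precise for large $L$ and $d$.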

3.4 Bi-Lipschitz networks

We construct bi-Lipschitz networks (referred to as BiLipNets) by composing strongly monotone and Lipschitz layers,

$$\mathcal{G} = \mathcal{O}_{K+1} \circ \mathcal{F}_K \circ \mathcal{O}_K \circ \mathcal{F}_{K-1} \circ \cdots \circ \mathcal{O}_2 \circ \mathcal{F}_1 \circ \mathcal{O}_1 \tag{8}$$

where $\mathcal{O}_k(x) = P_k x + q_k$ with $P_k^\top P_k = I$ is an orthogonal layer and $\mathcal{F}_k$ is a $\mu_k$-strongly monotone and $\nu_k$-Lipschitz layer (6). By the composition rule, the above BiLipNet is $(\mu,\nu)$-Lipschitz with $\mu = \prod_{k=1}^{K}\mu_k$ and $\nu = \prod_{k=1}^{K}\nu_k$. The orthogonal matrix $P_k$ can be parameterized via the Cayley transformation (4) or Householder transformation (Singla et al., 2022). Since the distortion of $\mathcal{O}_k$ is 1, it can improve network expressivity without increasing model distortion.
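The composition rule can be spelled out in a few lines. The sketch below assumes the standard definition of $\mu_k$-strong monotonicity, $\langle \mathcal{F}_k(x) - \mathcal{F}_k(x'), x - x'\rangle \geq \mu_k \|x - x'\|^2$, which is stated outside this section; the shorthand $\Delta x$ and $\Delta\mathcal{F}_k$ is introduced here only for illustration.

```latex
\begin{align*}
&\text{Let } \Delta x = x - x', \quad \Delta\mathcal{F}_k = \mathcal{F}_k(x) - \mathcal{F}_k(x'). \\
&\mu_k \|\Delta x\|^2 \le \langle \Delta\mathcal{F}_k, \Delta x \rangle \le \|\Delta\mathcal{F}_k\|\,\|\Delta x\|
  \;\Longrightarrow\; \mu_k \|\Delta x\| \le \|\Delta\mathcal{F}_k\| \le \nu_k \|\Delta x\|, \\
&\|\mathcal{O}_k(x) - \mathcal{O}_k(x')\| = \|P_k \Delta x\| = \|\Delta x\|
  \quad \text{since } P_k^\top P_k = I, \\
&\Bigl(\prod_{k=1}^{K} \mu_k\Bigr) \|\Delta x\| \;\le\; \|\mathcal{G}(x) - \mathcal{G}(x')\| \;\le\; \Bigl(\prod_{k=1}^{K} \nu_k\Bigr) \|\Delta x\|.
\end{align*}
```

The last line follows by applying the per-layer bounds along the chain (8); the orthogonal layers contribute a factor of exactly 1 to both bounds.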

In some applications, e.g., normalising flows (Dinh et al., 2015; Papamakarios et al., 2021), we need to compute the inverse of $\mathcal{G}$, which can be done in a backward manner:

$$\mathcal{G}^{-1}(y) = \mathcal{O}_1^{-1} \circ \mathcal{F}_1^{-1} \circ \cdots \circ \mathcal{O}_K^{-1} \circ \mathcal{F}_K^{-1} \circ \mathcal{O}_{K+1}^{-1}(y), \tag{9}$$

where $\mathcal{O}_k$ has an explicit inverse $\mathcal{O}_k^{-1}(y) = P_k^\top (y - q_k)$. Computing the inverse $\mathcal{F}_k^{-1}(y)$ requires an iterative solver, which will be addressed in Section 4.

Partially bi-Lipschitz networks.

A neural network $\tilde{\mathcal{G}}: \mathbb{R}^n \times \mathbb{R}^l \rightarrow \mathbb{R}^n$ is said to be partially bi-Lipschitz if for any fixed value of $p \in \mathbb{R}^l$, the mapping $y = \tilde{\mathcal{G}}(x; p)$ is $(\mu,\nu)$-Lipschitz from $x$ to $y$. We can construct such mappings via $\tilde{\mathcal{G}}(x; p) = \mathcal{G}_{h(p)}(x)$ where $\mathcal{G}_\phi$ is a $(\mu,\nu)$-Lipschitz network for any free parameter $\phi \in \mathbb{R}^N$ and $h: p \rightarrow \phi$ is a new learnable function. Since the dimension of $\phi$ is often very high, a practical approach is to make $\phi$ partially depend on $p$. For instance, we can learn a $p$-dependent bias via an MLP while the weight matrices of $\mathcal{G}_\phi$ are independent of $p$.

4 Model inverse via operator splitting

In this section we give an efficient algorithm to compute $\mathcal{F}^{-1}(y)$ where $\mathcal{F}$ is a $\mu$-strongly monotone and $\nu$-Lipschitz layer (6). First, we write its model inverse $\mathcal{F}^{-1}$ as

$$\begin{split}\hat{z} &= \hat{\sigma}\left(\left(V - \frac{\gamma}{\mu} S S^\top\right)\hat{z} + b_z\right) \\ x &= \frac{1}{\mu}\bigl(y - b_y - \sqrt{\gamma/2}\, S^\top \hat{z}\bigr)\end{split} \tag{10}$$

with $b_z = \frac{\sqrt{2\gamma}}{\mu} S (y - b_y) + \hat{b}$. Both $\mathcal{F}$ and $\mathcal{F}^{-1}$ can be treated as special cases of deep equilibrium networks (Bai et al., 2019; Winston & Kolter, 2020; Revay et al., 2020) or implicit networks (El Ghaoui et al., 2021). The difference is that $\mathcal{F}$ has an explicit formula due to the strictly lower-triangular $V$ while $\mathcal{F}^{-1}$ is an implicit equation as $SS^\top$ is a full matrix. A natural question for (10) is its well-posedness, i.e., for any $y \in \mathbb{R}^n$, does there exist a unique $\hat{z} \in \mathbb{R}^m$ satisfying (10)?
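For completeness, (10) follows from (6) by solving the output equation for $x$ and substituting into the hidden-state equation. The derivation below uses only the symbols of (6) and (10).

```latex
\begin{align*}
x &= \tfrac{1}{\mu}\bigl(y - b_y - \sqrt{\gamma/2}\, S^\top \hat{z}\bigr), \\
\sqrt{2\gamma}\, S x
  &= \tfrac{\sqrt{2\gamma}}{\mu} S (y - b_y) - \tfrac{\sqrt{2\gamma}\,\sqrt{\gamma/2}}{\mu}\, S S^\top \hat{z}
   = \tfrac{\sqrt{2\gamma}}{\mu} S (y - b_y) - \tfrac{\gamma}{\mu}\, S S^\top \hat{z}, \\
\hat{z} &= \hat{\sigma}\bigl(V \hat{z} + \sqrt{2\gamma}\, S x + \hat{b}\bigr)
   = \hat{\sigma}\Bigl(\bigl(V - \tfrac{\gamma}{\mu} S S^\top\bigr)\hat{z}
     + \underbrace{\tfrac{\sqrt{2\gamma}}{\mu} S (y - b_y) + \hat{b}}_{b_z}\Bigr).
\end{align*}
```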

Proposition 4.1.

$\mathcal{F}^{-1}$ is well-posed if $V, S$ are given by (7).

Certain classes of equilibrium networks were solved via two-operator splitting problems (Winston & Kolter, 2020; Revay et al., 2020). We follow a similar strategy, but our structure admits a three-operator splitting, see Proposition 4.2 with background in Appendix B. To state the result, we first recall the following fact from (Li et al., 2019). For the monotone and 1-Lipschitz activation $\hat{\sigma}$, there exists a proper convex function $f: \mathbb{R}^n \rightarrow \mathbb{R}$ satisfying $\hat{\sigma}(\cdot) = \mathbf{prox}_f^1(\cdot)$ with

$$\mathbf{prox}_f^\alpha(x) = \arg\min_{z \in \mathbb{R}^n} \; \frac{1}{2}\|x - z\|^2 + \alpha f(z).$$

A list of $f$ for popular activations is given in Section B.1.

Proposition 4.2.

Finding a solution $\hat{z} \in \mathbb{R}^m$ to (10) is equivalent to finding a zero to the three-operator splitting problem $0 \in \mathcal{A}(z) + \mathcal{B}(z) + \mathcal{C}(z)$ where $\mathcal{A}, \mathcal{B}, \mathcal{C}$ are monotone operators defined by

$$\mathcal{A}(z) = (I - V)z - b_z, \; \mathcal{B}(z) = \partial f(z), \; \mathcal{C}(z) = \frac{\gamma}{\mu} S S^\top z$$

where $f$ satisfies $\hat{\sigma}(\cdot) = \mathbf{prox}_f^1(\cdot)$.

For three-operator problems, the Davis-Yin splitting algorithm (DYS) (Davis & Yin, 2017) can be applied, obtaining the following fixed-point iteration:

$$\begin{split}z^{k+1/2} &= \mathbf{prox}_f^\alpha(u^k) \\ u^{k+1/2} &= 2z^{k+1/2} - u^k \\ z^{k+1} &= R_{\mathcal{A}}\bigl(u^{k+1/2} - \alpha\,\mathcal{C}(z^{k+1/2})\bigr) \\ u^{k+1} &= u^k + z^{k+1} - z^{k+1/2}\end{split} \tag{11}$$

where $R_{\mathcal{A}}(v) = \bigl((1+\alpha)I - \alpha V\bigr)^{-1}(v + \alpha b_z)$. Since $V$ is strictly lower triangular, we can solve $R_{\mathcal{A}}(v)$ using forward substitution. Furthermore, we can show that (11) is guaranteed to converge with $\alpha \in \bigl(0, \frac{1}{\tau - 1}\bigr)$, where $\tau$ is the model distortion.
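To make the iteration (11) concrete, the following NumPy sketch builds a small two-layer instance of (6) from the Cayley construction, evaluates it, and inverts it with the Davis-Yin iteration. It is an illustrative sketch under stated assumptions, not the authors' code: the sizes, seed, the constants `mu` and `gamma`, and the step size `alpha = 0.5 * mu / gamma` are hypothetical choices (the step-size bound coincides with $1/(\tau-1)$ if $\tau = \nu/\mu$ and $\nu = \mu + \gamma$, definitions that lie outside this section). The activation is assumed to be ReLU, for which $\hat{\sigma}$ equals ReLU for any positive diagonal $\Psi$ and the prox step is again ReLU for any $\alpha > 0$ (the associated $f$ is the indicator of the nonnegative orthant).

```python
import numpy as np

def cayley(G, H):
    """Cayley map (4); the returned J satisfies J^T J = I."""
    I = np.eye(G.shape[0])
    Z = G.T - G + H.T @ H
    M = np.linalg.inv(I + Z)
    return np.vstack([M @ (I - Z), -2.0 * H @ M])

def relu(v):
    return np.maximum(v, 0.0)

rng = np.random.default_rng(0)            # arbitrary seed
n, d, L = 2, 3, 2                         # hypothetical input size, width, depth
m = L * d
mu, gamma = 0.5, 1.5                      # hypothetical constants appearing in (6)

# Q (m x n) is the lower Cayley block; P is discarded.
Q = cayley(rng.standard_normal((n, n)), rng.standard_normal((m, n)))[n:]
S, V = np.zeros((m, n)), np.zeros((m, m))
A_prev = None
for k in range(L):                        # zero-based layer index
    rows = slice(k * d, (k + 1) * d)
    Fa = 0.1 * rng.standard_normal((d, d))
    Fb = 0.1 * rng.standard_normal((d if k > 0 else 0, d))   # m_0 = 0: empty block, B_1 = 0
    J = cayley(Fa, Fb)
    A, B = J[:d].T, J[d:].T
    S[rows] = A @ Q[rows]
    if k > 0:
        prev = slice((k - 1) * d, k * d)
        S[rows] -= B @ Q[prev]            # S_k = A_k Q_k - B_k Q_{k-1}
        V[rows, prev] = 2.0 * B @ A_prev.T
    A_prev = A
b_hat, b_y = rng.standard_normal(m), rng.standard_normal(n)

def forward(x):
    """Layer (6): V is strictly block lower triangular, so z is computed block by block."""
    z = np.zeros(m)
    for k in range(L):
        rows = slice(k * d, (k + 1) * d)
        z[rows] = relu(V[rows] @ z + np.sqrt(2 * gamma) * S[rows] @ x + b_hat[rows])
    return mu * x + np.sqrt(gamma / 2) * S.T @ z + b_y

def inverse(y, alpha, iters=3000):
    """Davis-Yin iteration (11) applied to (10), followed by the explicit formula for x."""
    b_z = np.sqrt(2 * gamma) / mu * S @ (y - b_y) + b_hat
    C = (gamma / mu) * S @ S.T
    R = (1 + alpha) * np.eye(m) - alpha * V   # lower triangular; a forward-substitution
    u = np.zeros(m)                           # routine could replace the generic solve
    for _ in range(iters):
        z_half = relu(u)
        u_half = 2 * z_half - u
        z = np.linalg.solve(R, u_half - alpha * C @ z_half + alpha * b_z)
        u = u + z - z_half
    z_hat = relu(u)
    return (y - b_y - np.sqrt(gamma / 2) * S.T @ z_hat) / mu

x_true = rng.standard_normal(n)
y = forward(x_true)
x_rec = inverse(y, alpha=0.5 * mu / gamma)
recon_error = np.linalg.norm(x_rec - x_true)
```

The small scale on `Fa` and `Fb` is a deliberate choice to keep this toy instance well conditioned; `recon_error` reports how closely the recovered input matches `x_true`.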

5 Polyak-Łojasiewicz Networks

We call a network $f: \mathbb{R}^n \rightarrow \mathbb{R}$ a Polyak-Łojasiewicz (PL) network, or PLNet for short, if it satisfies the following PL condition (Polyak, 1963; Lojasiewicz, 1963):

$$\frac{1}{2}\|\nabla_x f(x)\|^2 \geq m\bigl(f(x) - \min_x f(x)\bigr), \quad \forall x \in \mathbb{R}^n, \tag{12}$$

where $m > 0$. The PL condition is significant in optimization since it is weaker than convexity, but still implies that gradient methods converge to a global minimum with a linear rate (Karimi et al., 2016), making PLNet a promising candidate for learning surrogate loss models.

Proposition 5.1.

If $\mathcal{G}$ is $\mu$-inverse Lipschitz, then

$$f(x) = \frac{1}{2}\|\mathcal{G}(x)\|^2 + c, \quad c \in \mathbb{R} \tag{13}$$

is a PLNet with $m = \mu^2$.
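A short argument for the constant $m = \mu^2$ is sketched below. It assumes that $\mathcal{G}: \mathbb{R}^n \rightarrow \mathbb{R}^n$ is differentiable at $x$; the symbol $J_{\mathcal{G}}(x)$ for its (square) Jacobian is notation introduced here, not taken from the paper.

```latex
\begin{align*}
\nabla_x f(x) &= J_{\mathcal{G}}(x)^\top \mathcal{G}(x), \\
\|\mathcal{G}(x) - \mathcal{G}(x')\| \ge \mu \|x - x'\|
  &\;\Longrightarrow\; \|J_{\mathcal{G}}(x)\, v\| \ge \mu \|v\| \;\; \forall v
  \;\Longrightarrow\; \sigma_{\min}\bigl(J_{\mathcal{G}}(x)\bigr) \ge \mu, \\
\tfrac{1}{2}\|\nabla_x f(x)\|^2
  &= \tfrac{1}{2}\|J_{\mathcal{G}}(x)^\top \mathcal{G}(x)\|^2
  \ge \tfrac{\mu^2}{2}\|\mathcal{G}(x)\|^2
  = \mu^2 \bigl(f(x) - c\bigr)
  \ge \mu^2 \bigl(f(x) - \min_x f(x)\bigr).
\end{align*}
```

The second inequality in the last line uses that a square matrix and its transpose share the same singular values, and the final step uses $f \geq c$, hence $\min_x f(x) \geq c$.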

Remark 5.2.

We can further relax the quadratic assumption: $f(x) = h\bigl(\mathcal{G}(x)\bigr)$ is a PLNet if $h: \mathbb{R}^n \rightarrow \mathbb{R}$ is strongly convex (Karimi et al., 2016).

Remark 5.3.

For parametric optimization problems, one can learn a surrogate loss via $f(x; p) = \frac{1}{2}\|\tilde{\mathcal{G}}(x; p)\|^2 + c$ where $p \in \mathbb{R}^m$ is the problem-specific parameter and $\tilde{\mathcal{G}}$ is a partially bi-Lipschitz network.

Remark 5.4.

Any sub-level set $\mathbb{L}_\alpha = \{x : f(x) < \alpha\}$ with $\alpha > c$ is homeomorphic to a unit ball, making PLNets suitable for neural Lyapunov functions (Wilson, 1967). Applications of PLNets to learning Lyapunov stable neural dynamics can be found in (Cheng et al., 2024).

Computing global optimum of a PLNet.

If $f$ takes the form (13) and $\mathcal{G}$ is a bi-Lipschitz network (8), then $f$ has a unique global optimum $x^\star = \mathcal{G}^{-1}(0)$ with $\mathcal{G}^{-1}$ given by (9). This can be efficiently computed by analytical inversion of orthogonal layers and applying the DYS algorithm (11) to monotone and Lipschitz layers.

Limitations of gradient descent for finding global optimum.

An alternative way to compute the global optimum $x^\star$ is the standard gradient descent (GD) method $x^{k+1} = x^k - \alpha\nabla_x f(x^k)$. If $\nabla_x f$ is $L$-Lipschitz, then the above GD solver with $\alpha = 1/L$ has a linear global convergence rate of $1 - m/L$ with $m = \mu^2$ (Karimi et al., 2016). However, this method has two drawbacks. First, the gradient function $\nabla_x f$ may not be globally Lipschitz, see Example 5.5. Second, even if a global Lipschitz bound exists, it is generally hard to estimate.

Example 5.5.

Consider a scalar function $f(x) = 0.5\,g^2(x)$ with $g(x) = 2x + \sin x$, which satisfies the PL condition. Note that $\partial f/\partial x = (2 + \cos x)(2x + \sin x)$ is not globally Lipschitz due to the term $2x\cos x$.
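Both claims in this example can be probed numerically. The sketch below is illustrative only: the grid, the finite-difference step `h`, and the sample points $x_k = \pi/2 + 2\pi k$ are choices made here. It uses $m = 1$, which follows from $g'(x) = 2 + \cos x \geq 1$ (so $\mu = 1$) and $\min_x f(x) = f(0) = 0$.

```python
import math

def g(x):
    return 2.0 * x + math.sin(x)

def f(x):
    return 0.5 * g(x) ** 2

def grad_f(x):
    return (2.0 + math.cos(x)) * g(x)

# PL condition (12) with m = mu^2 = 1 and min f = 0, checked on a grid over [-50, 50].
m = 1.0
grid = [-50.0 + 0.01 * i for i in range(10001)]
pl_holds = all(
    0.5 * grad_f(x) ** 2 >= m * f(x) - 1e-9 * (1.0 + f(x))   # small floating-point slack
    for x in grid
)

def gradient_slope(x, h=1e-3):
    """Difference quotient of the gradient; any global Lipschitz constant must exceed it."""
    return abs(grad_f(x + h) - grad_f(x)) / h

# The quotient keeps growing along x_k = pi/2 + 2*pi*k, so no global bound exists.
slopes = [gradient_slope(math.pi / 2 + 2 * math.pi * k) for k in (1, 10, 100)]
```

The growing entries of `slopes` illustrate why a fixed GD step size $\alpha = 1/L$ cannot be chosen from a global Lipschitz constant for this $f$.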

6 Experiments

Here we present experiments which explore the expressive quality of the proposed models, regularisation via model distortion, and performance of the DYS solution method. Code is available at https://github.com/acfr/PLNet.

6.1 Uncertainty quantification via neural Gaussian process

It was shown in (Liu et al., 2020) that accurate uncertainty quantification of neural network models depends on a model’s ability to quantify the distance of a test example from the training data. This distance-awareness can be achieved with bi-Lipschitz residual layers (x)=x+(x)𝑥𝑥𝑥{\mathcal{F}}(x)=x+{\mathcal{H}}(x)caligraphic_F ( italic_x ) = italic_x + caligraphic_H ( italic_x ) and a Gaussian process output layer. In (Liu et al., 2020) this is achieved by imposing Lipschitz bound of 0<c<10𝑐10<c<10 < italic_c < 1 for {\mathcal{H}}caligraphic_H via spectral normalization. The resulting model is called Spectral-normalized Neural Gaussian Process (SNGP). In this section we examine the benefits of using the proposed BiLipNet in place of spectrally-normalized layers.

Toy example.

Using the two-moon dataset, we compare our $(\mu,\nu)$-Lipschitz network to an SNGP using a 3-layer i-ResNet under the same bi-Lipschitz constraints, i.e., $\mu=(1-c)^{3}$ and $\nu=(1+c)^{3}$, see Figure 4. For the lower-distortion case (i.e., small $c=0.1$), SNGP fails to completely separate the train and out-of-distribution (OOD) data due to its loose Lipschitz bound. Our model can distinguish the OOD examples from the training dataset, and its uncertainty surface is close to that of the SNGP with much higher distortion ($c=0.9$). As the model distortion increases, our model can have an uncertainty surface very close to the dataset. The uncertainty surface of SNGP does not change much from $c=0.1$ to $c=0.9$, see Figure 4 and additional results in Section D.2.

Figure 4: Predictive uncertainty of different NGPs with the same bi-Lipschitz bound. The points from dark blue and regions are classified as in-domain distribution and OOD data, respectively. Light blue and orange points (different colors indicate different labels) are training samples from the two-moon dataset. The red points are OOD test examples. For the case with small distortion, our model can still distinguish the train and OOD data, achieving results similar to those of SNGP with large distortion.

CIFAR-10/100.

For image datasets, the SNGP model in (Liu et al., 2020) contains three bi-Lipschitz components, each with four residual layers of the form $x+\mathcal{H}(x)$ where $\mathcal{H}$ is constructed to be $c$-Lipschitz using spectral normalization. To ensure certifiable bi-Lipschitzness, we modify the SNGP model by choosing $c\in(0,1)$ and removing batch normalization from $\mathcal{H}$, since it may re-scale a layer’s spectral norm in unexpected ways (Liu et al., 2023). The results of SNGP with batch normalization can be found in Section D.2. Our BiLipNet model has an architecture similar to SNGP, except that the bi-Lipschitz components are replaced with the proposed $(\mu,\nu)$-Lipschitz network (8). To ensure both models have the same bi-Lipschitz bound, we choose $\mu=(1-c)^{4}$ and $\nu=(1+c)^{4}$.

Table 1 reports the results of SNGP and BiLipNet under different bounds $c=0.95, 0.65, 0.35$. For the CIFAR-10 dataset, our model uniformly outperforms SNGP on both clean and corrupted data, i.e., it achieves higher accuracy (about $10\sim 20\%$ improvement), lower expected calibration error (ECE), and lower negative log-likelihood (NLL). A similar conclusion holds for CIFAR-100 on accuracy and NLL, though our model has slightly higher ECE.

As with the previous toy example, our model with small distortion ($\tau=18.6$ for $c=0.35$) achieves better accuracy than SNGP with large distortion ($\tau=2.3\times 10^{6}$ for $c=0.95$). Thus, we observe that our parameterization is much more expressive for a given distortion bound.
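The distortion values quoted above follow directly from the choice $\mu=(1-c)^{4}$ and $\nu=(1+c)^{4}$. The short Python sketch below reproduces them, assuming that $\tau$ denotes the ratio $\nu/\mu$ (an assumption that matches the quoted numbers); the function names are illustrative only.

```python
def bilipschitz_bounds(c, depth):
    # Bounds for a composition of `depth` residual layers x + H(x),
    # where each H is c-Lipschitz with 0 < c < 1.
    mu = (1.0 - c) ** depth   # inverse Lipschitz (lower) bound
    nu = (1.0 + c) ** depth   # Lipschitz (upper) bound
    return mu, nu

def distortion(c, depth):
    # Assumed definition: tau = nu / mu.
    mu, nu = bilipschitz_bounds(c, depth)
    return nu / mu

# Four residual layers per component, as in the CIFAR experiments.
for c in (0.95, 0.65, 0.35):
    print(f"c = {c:.2f}: tau = {distortion(c, 4):.3g}")
```

With `depth=3` the same helper gives the bounds used in the two-moon toy example.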

| Dataset | Method | $c$ | Accuracy (↑), clean | Accuracy (↑), corrupted | ECE (↓), clean | ECE (↓), corrupted | NLL (↓), clean | NLL (↓), corrupted |
|---|---|---|---|---|---|---|---|---|
| CIFAR-10 | SNGP | 0.95 | 76.7 ± 0.629 | 58.7 ± 1.000 | 0.057 ± 0.007 | 0.079 ± 0.006 | 0.682 ± 0.015 | 1.199 ± 0.041 |
| CIFAR-10 | SNGP | 0.65 | 72.5 ± 1.500 | 54.7 ± 1.778 | 0.058 ± 0.006 | 0.078 ± 0.006 | 0.797 ± 0.046 | 1.303 ± 0.057 |
| CIFAR-10 | SNGP | 0.35 | 62.7 ± 0.334 | 52.3 ± 0.721 | 0.069 ± 0.010 | 0.065 ± 0.006 | 1.055 ± 0.010 | 1.356 ± 0.018 |
| CIFAR-10 | BiLipNet | 0.95 | 86.2 ± 0.250 | 70.8 ± 0.469 | 0.020 ± 0.003 | 0.052 ± 0.005 | 0.423 ± 0.006 | 0.895 ± 0.020 |
| CIFAR-10 | BiLipNet | 0.65 | 86.7 ± 0.129 | 72.8 ± 0.592 | 0.015 ± 0.005 | 0.047 ± 0.009 | 0.400 ± 0.006 | 0.830 ± 0.024 |
| CIFAR-10 | BiLipNet | 0.35 | 84.5 ± 0.184 | 72.6 ± 0.216 | 0.010 ± 0.002 | 0.052 ± 0.004 | 0.457 ± 0.002 | 0.827 ± 0.008 |
| CIFAR-100 | SNGP | 0.95 | 36.9 ± 1.656 | 25.5 ± 1.406 | 0.131 ± 0.010 | 0.068 ± 0.005 | 2.493 ± 0.068 | 3.073 ± 0.069 |
| CIFAR-100 | SNGP | 0.65 | 33.0 ± 0.481 | 24.3 ± 0.749 | 0.117 ± 0.006 | 0.068 ± 0.003 | 2.683 ± 0.015 | 3.140 ± 0.048 |
| CIFAR-100 | SNGP | 0.35 | 26.5 ± 1.630 | 19.3 ± 1.296 | 0.101 ± 0.016 | 0.056 ± 0.010 | 3.020 ± 0.062 | 3.406 ± 0.073 |
| CIFAR-100 | BiLipNet | 0.95 | 51.0 ± 0.480 | 35.8 ± 0.397 | 0.230 ± 0.006 | 0.137 ± 0.007 | 2.064 ± 0.024 | 2.718 ± 0.014 |
| CIFAR-100 | BiLipNet | 0.65 | 55.2 ± 0.426 | 39.2 ± 0.495 | 0.225 ± 0.004 | 0.137 ± 0.005 | 1.887 ± 0.021 | 2.576 ± 0.022 |
| CIFAR-100 | BiLipNet | 0.35 | 54.4 ± 0.438 | 41.1 ± 0.200 | 0.194 ± 0.008 | 0.126 ± 0.009 | 1.876 ± 0.031 | 2.447 ± 0.016 |
Table 1: Results for SNGP and BiLipNet on CIFAR-10/100, averaged over 5 seeds. To ensure bi-Lipschitz bounds, batch normalization is removed from SNGP. BiLipNet uniformly and significantly outperforms SNGP in terms of accuracy on both clean and corrupted data.

6.2 Surrogate loss learning

We explore the PLNet’s performance with the Rosenbrock function $r(x,y)=\frac{1}{200}(x-1)^{2}+0.5(y-x^{2})^{2}$ and its higher-dimensional generalizations. The Rosenbrock function is a classical test problem for optimization, since it is non-convex but has a unique global minimum point $(1,1)$, at which the Hessian is poorly conditioned. We also consider the sum of the Rosenbrock function and a 2D sine wave function, which still has a unique global minimum at $(1,1)$ while having many local minima, see Section D.1.
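The properties of the scaled Rosenbrock function stated above can be verified by hand or with a few lines of code. The sketch below (plain Python, with illustrative helper names that are not taken from the paper's repository) evaluates the analytic gradient and Hessian of $r$, confirms that $(1,1)$ is a stationary point with value zero, computes the condition number of the Hessian there, and exhibits a point with negative curvature, which shows non-convexity.

```python
import math

def rosenbrock(x, y):
    # Scaled Rosenbrock function used in the paper.
    return (x - 1.0) ** 2 / 200.0 + 0.5 * (y - x ** 2) ** 2

def gradient(x, y):
    dx = (x - 1.0) / 100.0 - 2.0 * x * (y - x ** 2)
    dy = y - x ** 2
    return dx, dy

def hessian(x, y):
    # Entries (h_xx, h_xy, h_yy) of the symmetric 2x2 Hessian.
    hxx = 0.01 - 2.0 * (y - x ** 2) + 4.0 * x ** 2
    hxy = -2.0 * x
    hyy = 1.0
    return hxx, hxy, hyy

def eigenvalues_2x2(hxx, hxy, hyy):
    # Closed-form eigenvalues of a symmetric 2x2 matrix.
    mean = 0.5 * (hxx + hyy)
    radius = math.sqrt((0.5 * (hxx - hyy)) ** 2 + hxy ** 2)
    return mean - radius, mean + radius

# Conditioning at the global minimum (1, 1).
lam_min, lam_max = eigenvalues_2x2(*hessian(1.0, 1.0))
condition_number = lam_max / lam_min

# Negative curvature at (0, 1) shows that r is non-convex.
lam_min_nonconvex, _ = eigenvalues_2x2(*hessian(0.0, 1.0))
```

At $(1,1)$ the Hessian is $\begin{bmatrix}4.01 & -2\\ -2 & 1\end{bmatrix}$ with determinant $0.01$, which is why the condition number is in the thousands.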

We learned models of the form (13) where $\mathcal{G}$ is parameterized by an MLP, i-ResNet (Behrmann et al., 2019), i-DenseNet (Perugachi-Diaz et al., 2021) and the proposed BiLipNet (8). We also trained the ICNN, a scalar-output model which is convex w.r.t. inputs (Amos et al., 2017).

From Figure 7, we have the following observations. The unconstrained MLP can achieve small test errors. However, it has many local minima near the valley $y=x^{2}$. This phenomenon is more easily visible for the Rosenbrock+Sine case but also occurs in the plain Rosenbrock case. The ICNN model has a unique global minimum but the fitting error is large as its sub-level sets are convex. For i-DenseNet, the sub-level sets become mildly non-convex but their bi-Lipschitz bound is quite conservative, so they do not capture the overall shape. In contrast, our proposed BiLipNet is more flexible and captures the non-convex shape while maintaining a unique global minimum. We note that in the Rosenbrock+Sine case, the BiLipNet surrogate has errors of similar magnitude to the MLP, but remains “easily optimizable”, i.e., it satisfies the PL condition and has a unique global minimum. Additional results are in Section D.2.

Partial PLNet.

We also fit a parameterized Rosenbrock function $r(x,y;p)$ using partial PLNet with $p$-dependent biases (see Remark 5.3). The results in Figure 8 indicate that the approach can be effective even if only bias terms are modified by the external parameter $p$, and not weights.

High-dimensional case.

We now turn to the scalability of the approach to higher-dimensional problems and analyse the convergence of the DYS method for computing the global minimum. We apply the approach to an $N=20$-dimensional version of the Rosenbrock function:

$$R(x)=\frac{1}{N-1}\sum_{i=1}^{N-1}r(x_{i},x_{i+1}) \tag{14}$$

which has a global minimum of zero at $x=(1,1,\ldots,1)$ but is non-convex and has spurious local minima (Kok & Sandrock, 2009). We sample 10K training points uniformly over $[-2,2]^{20}$. Note that, in contrast to the 2D example above, this is a very sparse sampling of the 20-dimensional space.
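The data-generation step described above can be sketched as follows. This is an illustrative reconstruction in plain Python, not the authors' code: the function names, the seed, and the use of Python's `random` module are assumptions made for the example.

```python
import random

def rosenbrock_pair(x, y):
    # 2D scaled Rosenbrock term r(x, y).
    return (x - 1.0) ** 2 / 200.0 + 0.5 * (y - x ** 2) ** 2

def rosenbrock_nd(x):
    # N-dimensional version (14): average of r over consecutive coordinate pairs.
    n = len(x)
    return sum(rosenbrock_pair(x[i], x[i + 1]) for i in range(n - 1)) / (n - 1)

def sample_training_set(num_points, dim=20, low=-2.0, high=2.0, seed=0):
    # Uniform samples over the box [low, high]^dim with their function values.
    rng = random.Random(seed)
    inputs = [[rng.uniform(low, high) for _ in range(dim)]
              for _ in range(num_points)]
    targets = [rosenbrock_nd(x) for x in inputs]
    return inputs, targets

inputs, targets = sample_training_set(num_points=10_000)
```

Because every term of the sum is non-negative and vanishes only when $x_{i+1}=x_i^{2}$ and $x_i=1$, uniformly sampled points in 20 dimensions essentially never land near the minimizer, which is one way to see why the sampling is sparse.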

A comparison of train and test error vs model distortion is shown in Figure 5. It can be seen that our proposed BiLipNet model achieves far better fits than iResNet (Behrmann et al., 2021) and iDenseNet (Perugachi-Diaz et al., 2021), which cannot achieve small training error for any value of the distortion parameter. Furthermore, for our network, the distortion parameter appears to act as an effective regularizer. Note that the best test error occurs after the training error drops to near zero ($\sim 10^{-8}$) but while the distortion is still relatively small.

Figure 5: Surrogate loss learning for 20-dimensional Rosenbrock function. Comparison of training and test error vs model distortion for PLNet with different bi-Lipschitz models.

Solver comparison.

Given the surrogate loss function learned by BiLipNet, we now compare methods to compute the location of its global minimum. In Figure 6 we compare the proposed DYS solver to the forward step method (FSM), see, e.g., (Ryu & Boyd, 2016). Specifically, the inverse $x=\mathcal{F}^{-1}(y)$ with $\mathcal{F}$ as a $\mu$-strongly monotone and $\nu$-Lipschitz layer can be computed via

$$x^{k+1}=x^{k}-\alpha(\mathcal{F}(x^{k})-y) \tag{15}$$

which has a convergence rate of $1-\mu^{2}/\nu^{2}$ if $\alpha=\mu/\nu^{2}$. We also consider a commonly used gradient-based method, ADAM (Kingma & Ba, 2015), applied directly to the surrogate loss. We take two values of the distortion parameter: $\tau=5$ (optimal) and $\tau=50$. In both cases, the proposed DYS method converges significantly faster than the alternatives, and the results illustrate an additional benefit of regularising via distortion, besides improving the test error: the $\tau=5$ case converges significantly faster than $\tau=50$.
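To make the FSM baseline (15) concrete, the following minimal Python sketch applies it to a scalar stand-in for a monotone layer. The map $\mathcal{F}(x)=2x+\sin x$ is borrowed from Example 5.5 and is not a BiLipNet layer; it is $\mu$-strongly monotone with $\mu=1$ and $\nu$-Lipschitz with $\nu=3$. The function name, the target value and the iteration count are assumptions chosen for illustration, and the sketch does not implement the proposed DYS solver.

```python
import math

def forward_step_inverse(F, y, mu, nu, x0=0.0, num_iters=300):
    # Forward step method (15): x <- x - alpha * (F(x) - y),
    # with the step size alpha = mu / nu^2 quoted in the text.
    alpha = mu / nu ** 2
    x = x0
    for _ in range(num_iters):
        x = x - alpha * (F(x) - y)
    return x

def F(x):
    # Scalar stand-in: F'(x) = 2 + cos(x) lies in [1, 3].
    return 2.0 * x + math.sin(x)

y_target = 5.0
x_inv = forward_step_inverse(F, y_target, mu=1.0, nu=3.0)
residual = abs(F(x_inv) - y_target)
```

For this example $\mu^{2}/\nu^{2}=1/9$, so the quoted rate $1-\mu^{2}/\nu^{2}=8/9$ is close to one. The rate approaches one as the distortion $\nu/\mu$ grows, which is consistent with the slower convergence reported for $\tau=50$.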

At the computed point $x^{\star}=\mathcal{G}^{-1}(0)$ for $\tau=5$, the true function (14) takes a value of $R(x^{\star})=0.041$. This is more than an order of magnitude better than the smallest value of $R(x)$ over the training data, which ranged over $[0.475, 6.532]$, indicating that PLNets have a useful “implicit bias” and do not simply interpolate the training data.

Figure 6: Solver comparison for finding the global minimum of a PLNet. We try a range of rates $[0.1, 0.5, 1.0, 2.0, 5.0]$ for ADAM and present the best result. The proposed back solve method with DYS algorithm (11) converges much faster than ADAM applied to $f$ or back solve method with FSM algorithm (15).
Figure 7: Learning a surrogate loss for the Rosenbrock and Rosenbrock+Sine functions, which is non-convex and has many local minima. The first row contains the true functions while the remaining rows show learned functions and errors for various surrogate loss models.
Figure 8: Learning a parameterized Rosenbrock function $r(x,y;a,b)=\frac{1}{200}(x-a)^{2}+0.5(y-bx^{2})^{2}$ via a partial PLNet.

7 Conclusion

This paper has introduced a new bi-Lipschitz network architecture, the BiLipNet, and a new scalar-output network, the PLNet, which satisfies the Polyak-Łojasiewicz condition and is hence “easily optimizable”.

The core technical contribution is a new layer type: the “feed-through” layer, which has certified bounds for strong monotonicity and Lipschitzness. By composing it with orthogonal layers we obtain a bi-Lipschitz network structure (BiLipNet) which has much tighter bounds than existing bi-Lipschitz residual networks based on spectral normalization. The PLNet composes a BiLipNet with a quadratic output layer, and guarantees a unique global minimum which is efficiently computable.

Impact Statement

There are many application domains in which the trustworthiness of machine learning is a live topic of debate and raises important and challenging questions. The goal of this paper is to advance the sub-field of machine learning methods which have mathematically-certified properties. In particular, one application considered in this paper is uncertainty quantification. We hope that a positive impact of our paper and others like it will be the development of ML methods that can better satisfy societal expectations of trustworthiness and transparency.

We are not aware of any potentially significant negative impacts that are particularly associated with this line of research (models with certified properties).

References

  • Ahn et al. (2022) Ahn, B., Kim, C., Hong, Y., and Kim, H. J. Invertible monotone operators for normalizing flows. Advances in Neural Information Processing Systems, 35:16836–16848, 2022.
  • Amos et al. (2017) Amos, B., Xu, L., and Kolter, J. Z. Input convex neural networks. In International Conference on Machine Learning (ICML), pp. 146–155. PMLR, 2017.
  • Araujo et al. (2023) Araujo, A., Havens, A. J., Delattre, B., Allauzen, A., and Hu, B. A unified algebraic perspective on Lipschitz neural networks. In The Eleventh International Conference on Learning Representations (ICLR), 2023.
  • Ardizzone et al. (2018) Ardizzone, L., Kruse, J., Rother, C., and Köthe, U. Analyzing inverse problems with invertible neural networks. In International Conference on Learning Representations (ICLR), 2018.
  • Arjovsky et al. (2017) Arjovsky, M., Chintala, S., and Bottou, L. Wasserstein generative adversarial networks. In International conference on machine learning (ICML), pp. 214–223. PMLR, 2017.
  • Arora & Doshi (2021) Arora, S. and Doshi, P. A survey of inverse reinforcement learning: Challenges, methods and progress. Artificial Intelligence, 297:103500, 2021.
  • Bai et al. (2019) Bai, S., Kolter, J. Z., and Koltun, V. Deep equilibrium models. In Advances in Neural Information Processing Systems, pp. 690–701, 2019.
  • Barbara et al. (2024) Barbara, N. H., Wang, R., and Manchester, I. R. On robust reinforcement learning with lipschitz-bounded policy networks. arXiv preprint arXiv:2405.11432, 2024.
  • Bauer & Mnih (2019) Bauer, M. and Mnih, A. Resampled priors for variational autoencoders. In The 22nd International Conference on Artificial Intelligence and Statistics, pp.  66–75. PMLR, 2019.
  • Behrmann et al. (2019) Behrmann, J., Grathwohl, W., Chen, R. T., Duvenaud, D., and Jacobsen, J.-H. Invertible residual networks. In International conference on machine learning (ICML), pp. 573–582. PMLR, 2019.
  • Behrmann et al. (2021) Behrmann, J., Vicol, P., Wang, K.-C., Grosse, R., and Jacobsen, J.-H. Understanding and mitigating exploding inverses in invertible neural networks. In International Conference on Artificial Intelligence and Statistics, pp.  1792–1800. PMLR, 2021.
  • Chen et al. (2019) Chen, R. T., Behrmann, J., Duvenaud, D. K., and Jacobsen, J.-H. Residual flows for invertible generative modeling. Advances in Neural Information Processing Systems, 32, 2019.
  • Cheng et al. (2024) Cheng, J., Wang, R., and Manchester, I. R. Learning stable and passive neural differential equations. arXiv preprint arXiv:2404.12554, 2024.
  • Coleman et al. (2017) Coleman, C., Narayanan, D., Kang, D., Zhao, T., Zhang, J., Nardi, L., Bailis, P., Olukotun, K., Ré, C., and Zaharia, M. Dawnbench: An end-to-end deep learning benchmark and competition. Training, 100(101):102, 2017.
  • Cozad et al. (2014) Cozad, A., Sahinidis, N. V., and Miller, D. C. Learning surrogate models for simulation-based optimization. AIChE Journal, 60(6):2211–2227, 2014.
  • Davis & Yin (2017) Davis, D. and Yin, W. A three-operator splitting scheme and its optimization applications. Set-valued and variational analysis, 25:829–858, 2017.
  • Davis (2006) Davis, T. A. Direct methods for sparse linear systems. SIAM, 2006.
  • De Cao et al. (2020) De Cao, N., Aziz, W., and Titov, I. Block neural autoregressive flow. In Uncertainty in artificial intelligence, pp.  1263–1273. PMLR, 2020.
  • Dinh et al. (2015) Dinh, L., Krueger, D., and Bengio, Y. Nice: Non-linear independent components estimation. ICLR Workshop Track, 2015.
  • Dinh et al. (2017) Dinh, L., Sohl-Dickstein, J., and Bengio, S. Density estimation using real NVP. In International Conference on Learning Representations (ICLR), 2017.
  • El Ghaoui et al. (2021) El Ghaoui, L., Gu, F., Travacca, B., Askari, A., and Tsai, A. Implicit deep learning. SIAM Journal on Mathematics of Data Science, 3(3):930–958, 2021.
  • Fazlyab et al. (2019) Fazlyab, M., Robey, A., Hassani, H., Morari, M., and Pappas, G. Efficient and accurate estimation of Lipschitz constants for deep neural networks. In Advances in Neural Information Processing Systems, pp. 11427–11438, 2019.
  • Grathwohl et al. (2019) Grathwohl, W., Chen, R. T., Bettencourt, J., and Duvenaud, D. Scalable reversible generative models with free-form continuous dynamics. In International Conference on Learning Representations (ICLR), 2019.
  • Grudzien et al. (2024) Grudzien, K., Uehara, M., Levine, S., and Abbeel, P. Functional graphical models: Structure enables offline data-driven optimization. In International Conference on Artificial Intelligence and Statistics, pp.  2449–2457. PMLR, 2024.
  • Gu et al. (2016) Gu, S., Lillicrap, T., Sutskever, I., and Levine, S. Continuous deep Q-learning with model-based acceleration. In International conference on machine learning (ICML), pp. 2829–2838. PMLR, 2016.
  • Gulrajani et al. (2017) Gulrajani, I., Ahmed, F., Arjovsky, M., Dumoulin, V., and Courville, A. C. Improved training of Wasserstein GANs. Advances in neural information processing systems, 30, 2017.
  • Havens et al. (2023) Havens, A. J., Araujo, A., Garg, S., Khorrami, F., and Hu, B. Exploiting connections between Lipschitz structures for certifiably robust deep equilibrium models. In Thirty-seventh Conference on Neural Information Processing Systems, 2023.
  • Helfrich et al. (2018) Helfrich, K., Willmott, D., and Ye, Q. Orthogonal recurrent neural networks with scaled Cayley transform. In International Conference on Machine Learning (ICML), pp. 1969–1978. PMLR, 2018.
  • Ho et al. (2019) Ho, J., Chen, X., Srinivas, A., Duan, Y., and Abbeel, P. Flow++: Improving flow-based generative models with variational dequantization and architecture design. In International Conference on Machine Learning (ICML), pp. 2722–2730. PMLR, 2019.
  • Huang et al. (2018) Huang, C.-W., Krueger, D., Lacoste, A., and Courville, A. Neural autoregressive flows. In International Conference on Machine Learning (ICML), pp. 2078–2087. PMLR, 2018.
  • Karimi et al. (2016) Karimi, H., Nutini, J., and Schmidt, M. Linear convergence of gradient and proximal-gradient methods under the Polyak-Łojasiewicz condition. In Machine Learning and Knowledge Discovery in Databases: European Conference, ECML PKDD 2016, Riva del Garda, Italy, September 19-23, 2016, Proceedings, Part I 16, pp.  795–811. Springer, 2016.
  • Kingma & Ba (2015) Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization. In International Conference on Learning Representations (ICLR), 2015.
  • Kingma & Dhariwal (2018) Kingma, D. P. and Dhariwal, P. Glow: Generative flow with invertible 1x1 convolutions. Advances in neural information processing systems, 31, 2018.
  • Kobyzev et al. (2020) Kobyzev, I., Prince, S. J., and Brubaker, M. A. Normalizing flows: An introduction and review of current methods. IEEE transactions on pattern analysis and machine intelligence, 43(11):3964–3979, 2020.
  • Kok & Sandrock (2009) Kok, S. and Sandrock, C. Locating and Characterizing the Stationary Points of the Extended Rosenbrock Function. Evolutionary Computation, 17(3):437–453, 2009.
  • Li et al. (2019) Li, J., Fang, C., and Lin, Z. Lifted proximal operator machines. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 33, pp.  4181–4188, 2019.
  • Li et al. (2020) Li, J., Li, F., and Todorovic, S. Efficient riemannian optimization on the stiefel manifold via the Cayley transform. In International Conference on Learning Representations (ICLR), 2020.
  • Liang et al. (2023) Liang, E., Chen, M., and Low, S. Low complexity homeomorphic projection to ensure neural-network solution feasibility for optimization over (non-) convex set. In International conference on machine learning (ICML). PMLR, 2023.
  • Liu et al. (2020) Liu, J., Lin, Z., Padhy, S., Tran, D., Bedrax Weiss, T., and Lakshminarayanan, B. Simple and principled uncertainty estimation with deterministic deep learning via distance awareness. Advances in Neural Information Processing Systems, 33:7498–7512, 2020.
  • Liu et al. (2023) Liu, J. Z., Padhy, S., Ren, J., Lin, Z., Wen, Y., Jerfel, G., Nado, Z., Snoek, J., Tran, D., and Lakshminarayanan, B. A simple approach to improve single-model deep uncertainty via distance-awareness. Journal of Machine Learning Research, 24(42):1–63, 2023.
  • Lojasiewicz (1963) Lojasiewicz, S. A topological property of real analytic subsets. Coll. du CNRS, Les équations aux dérivées partielles, 117(87-89):2, 1963.
  • Louizos & Welling (2017) Louizos, C. and Welling, M. Multiplicative normalizing flows for variational bayesian neural networks. In International Conference on Machine Learning (ICML), pp. 2218–2227. PMLR, 2017.
  • Lu et al. (2021) Lu, C., Chen, J., Li, C., Wang, Q., and Zhu, J. Implicit normalizing flows. In International Conference on Learning Representations (ICLR), 2021.
  • Megretski & Rantzer (1997) Megretski, A. and Rantzer, A. System analysis via integral quadratic constraints. IEEE Transactions on Automatic Control, 42(6):819–830, 1997.
  • Meunier et al. (2022) Meunier, L., Delattre, B. J., Araujo, A., and Allauzen, A. A dynamical system perspective for Lipschitz neural networks. In International Conference on Machine Learning (ICML), pp. 15484–15500. PMLR, 2022.
  • Misener & Biegler (2023) Misener, R. and Biegler, L. Formulating data-driven surrogate models for process optimization. Computers & Chemical Engineering, 179:108411, 2023.
  • Miyato et al. (2018) Miyato, T., Kataoka, T., Koyama, M., and Yoshida, Y. Spectral normalization for generative adversarial networks. In International Conference on Learning Representations (ICLR), 2018.
  • Papamakarios et al. (2021) Papamakarios, G., Nalisnick, E., Rezende, D. J., Mohamed, S., and Lakshminarayanan, B. Normalizing flows for probabilistic modeling and inference. The Journal of Machine Learning Research, 22(1):2617–2680, 2021.
  • Pauli et al. (2024) Pauli, P., Havens, A., Araujo, A., Garg, S., Khorrami, F., Allgöwer, F., and Hu, B. Novel quadratic constraints for extending LipSDP beyond slope-restricted activations. In International Conference on Learning Representations (ICLR), 2024.
  • Perugachi-Diaz et al. (2021) Perugachi-Diaz, Y., Tomczak, J., and Bhulai, S. Invertible densenets with concatenated Lipswish. Advances in Neural Information Processing Systems, 34:17246–17257, 2021.
  • Polyak (1963) Polyak, B. Gradient methods for minimizing functionals (in russian). USSR Computational Mathematics and Mathematical Physics, 3(4):643–653, 1963.
  • Prach & Lampert (2022) Prach, B. and Lampert, C. H. Almost-orthogonal layers for efficient general-purpose Lipschitz networks. In European Conference on Computer Vision, pp.  350–365. Springer, 2022.
  • Rantzer (1996) Rantzer, A. On the Kalman—Yakubovich—Popov lemma. Systems & control letters, 28(1):7–10, 1996.
  • Revay et al. (2020) Revay, M., Wang, R., and Manchester, I. R. Lipschitz bounded equilibrium networks. arXiv preprint arXiv:2010.01732, 2020.
  • Revay et al. (2023) Revay, M., Wang, R., and Manchester, I. R. Recurrent equilibrium networks: Flexible dynamic models with guaranteed stability and robustness. IEEE Transactions on Automatic Control, 2023.
  • Russo & Proutiere (2021) Russo, A. and Proutiere, A. Towards optimal attacks on reinforcement learning policies. In 2021 American Control Conference (ACC), pp.  4561–4567. IEEE, 2021.
  • Ryu & Boyd (2016) Ryu, E. K. and Boyd, S. Primer on monotone operator methods. Appl. comput. math, 15(1):3–43, 2016.
  • Ryu et al. (2019) Ryu, M., Chow, Y., Anderson, R., Tjandraatmadja, C., and Boutilier, C. CAQL: Continuous action Q-learning. In International Conference on Learning Representations (ICLR), 2019.
  • Singla & Feizi (2021) Singla, S. and Feizi, S. Skew orthogonal convolutions. In International Conference on Machine Learning (ICML), pp. 9756–9766. PMLR, 2021.
  • Singla et al. (2022) Singla, S., Singla, S., and Feizi, S. Improved deterministic l2 robustness on CIFAR-10 and CIFAR-100. In International Conference on Learning Representations (ICLR), 2022.
  • Trockman & Kolter (2021) Trockman, A. and Kolter, J. Z. Orthogonalizing convolutional layers with the Cayley transform. In International Conference on Learning Representations (ICLR), 2021.
  • Tsuzuku et al. (2018) Tsuzuku, Y., Sato, I., and Sugiyama, M. Lipschitz-margin training: Scalable certification of perturbation invariance for deep neural networks. In Advances in neural information processing systems, pp. 6541–6550, 2018.
  • Wang & Manchester (2023) Wang, R. and Manchester, I. Direct parameterization of Lipschitz-bounded deep networks. In International Conference on Machine Learning (ICML), pp. 36093–36110. PMLR, 2023.
  • Wang et al. (2022) Wang, Z., Prakriya, G., and Jha, S. A quantitative geometric approach to neural-network smoothness. Advances in Neural Information Processing Systems, 35:34201–34215, 2022.
  • Ward et al. (2019) Ward, P. N., Smofsky, A., and Bose, A. J. Improving exploration in soft-actor-critic with normalizing flows policies. ICML Workshop on Invertible Neural Networks and Normalizing Flows, 2019.
  • Wilson (1967) Wilson, F. W. The structure of the level surfaces of a Lyapunov function. Journal of Differential Equations, 3(3):323–329, 1967.
  • Winston & Kolter (2020) Winston, E. and Kolter, J. Z. Monotone operator equilibrium networks. Advances in neural information processing systems, 33:10718–10728, 2020.
  • Yeh (2006) Yeh, J. Real analysis: theory of measure and integration second edition. World Scientific Publishing Company, 2006.
  • Zhang et al. (2021) Zhang, B., Cai, T., Lu, Z., He, D., and Wang, L. Towards certifying l-infinity robustness using neural networks with l-inf-dist neurons. In International Conference on Machine Learning, pp. 12368–12379. PMLR, 2021.

Appendix A Model Parameterization

A model parameterization is a mapping $\mathcal{M}:\phi\rightarrow\theta$ where $\phi\in\mathbb{R}^{N}$ is a free learnable parameter while $\theta$ includes the model weights $U\in\mathbb{R}^{m\times n}$, $W\in\mathbb{R}^{m\times m}$, $Y\in\mathbb{R}^{n\times m}$ and IQC multiplier $\Lambda\in\mathbb{D}_{+}^{m}$ with $n,m$ as the dimensions of the input and hidden units, respectively. The aim of this section is to construct a parameterization such that the large-scale SDP constraint (3) holds, i.e., $Y=U^{\top}\Lambda$ and

$$
H=2\Lambda-W^{\top}\Lambda-\Lambda W=\begin{bmatrix}
2\Lambda_{1} & -W_{2}^{\top}\Lambda_{2} \\
-\Lambda_{2}W_{2} & 2\Lambda_{2} & -W_{3}^{\top}\Lambda_{3} \\
 & \ddots & \ddots & \ddots \\
 & & -\Lambda_{L-1}W_{L-1} & 2\Lambda_{L-1} & -W_{L}^{\top}\Lambda_{L} \\
 & & & -\Lambda_{L}W_{L} & 2\Lambda_{L}
\end{bmatrix}\geq\frac{2}{\gamma}Y^{\top}Y \tag{16}
$$

Since $H\succeq 0$ has band structure, it can be represented by $H=XX^{\top}$ (Davis, 2006). Moreover, from Lemma 3 of (Rantzer, 1996) we have that any $U,Y$ satisfying $Y=U^{\top}\Lambda$ and $XX^{\top}\succeq\frac{2}{\gamma}Y^{\top}Y$ can be represented by

$$U=\sqrt{\gamma/2}\,\Lambda^{-1}XQ,\quad Y=\sqrt{\gamma/2}\,Q^{\top}X^{\top} \tag{17}$$

where $Q\in\mathbb{R}^{m\times n}$ with $QQ^{\top}\preceq I$. The remaining task is to find $X$ such that $H=XX^{\top}$ has the same sparse structure as (16), which was solved by (Wang & Manchester, 2023). For completeness, we give the detailed construction below. First, we further parameterize $X=\Psi P$, where $\Psi=\mathrm{diag}(\Psi_{1},\ldots,\Psi_{L})$ with $\Psi_{k}\in\mathbb{D}_{+}^{m_{k}}$ and

$$P=\begin{bmatrix}A_{1}\\ -B_{2}&A_{2}\\ &\ddots&\ddots\\ &&-B_{L}&A_{L}\end{bmatrix}.$$

By comparing $H=\Psi PP^{\top}\Psi$ with (16) we have

$$H_{kk}=\Psi_{k}(B_{k}B_{k}^{\top}+A_{k}A_{k}^{\top})\Psi_{k}=2\Lambda_{k},\quad H_{k,k-1}=-\Psi_{k}B_{k}A_{k-1}^{\top}\Psi_{k-1}=-\Lambda_{k}W_{k},$$

which further leads to

$$\Psi_{k}^{2}=2\Lambda_{k},\quad B_{k}B_{k}^{\top}+A_{k}A_{k}^{\top}=I,\quad W_{k}=2\Psi_{k}^{-1}B_{k}A_{k-1}^{\top}\Psi_{k-1},\quad k=1,\ldots,L, \tag{18}$$

with $B_{1}=0$. We have converted the large-scale SDP constraint (16) into many simple and small-scale constraints such as

$$\Psi_{k}^{2}=2\Lambda_{k},\quad R_{k}R_{k}^{\top}=I,\quad QQ^{\top}\preceq I \tag{19}$$

with $R_{k}=\begin{bmatrix}B_{k}&A_{k}\end{bmatrix}$, which in turn can be easily parameterized via the Cayley transformation (4), see Section 3.3. The Cayley transformation has been applied to construct orthogonal layers (Helfrich et al., 2018; Li et al., 2020; Trockman & Kolter, 2021) and the 1-Lipschitz Sandwich layer (Wang & Manchester, 2023).
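The construction in (16)-(19) can be checked numerically. The following is a minimal NumPy sketch, not the authors' implementation: the helper names `semi_orthogonal` and `build_weights` are hypothetical, the layer widths and $\gamma$ are arbitrary, and a QR factorization stands in for the Cayley transformation (4) purely as a convenient way to draw matrices with $R_kR_k^{\top}=I$ and $QQ^{\top}\preceq I$.

```python
import numpy as np

def semi_orthogonal(rows, cols, rng):
    """Return R of shape (rows, cols) with R @ R.T = I (needs rows <= cols)."""
    basis, _ = np.linalg.qr(rng.standard_normal((cols, rows)))  # orthonormal columns
    return basis.T

def build_weights(n, widths, gamma, rng):
    """Map random draws to (U, W, Y, Lam, H) following (17)-(19)."""
    offs = np.concatenate(([0], np.cumsum(widths)))
    m = int(offs[-1])
    psi = np.exp(rng.standard_normal(m))          # positive diagonal of Psi
    Lam = np.diag(psi**2 / 2.0)                   # Psi_k^2 = 2 Lambda_k
    P = np.zeros((m, m))
    for k, mk in enumerate(widths):
        rk = slice(offs[k], offs[k + 1])
        if k == 0:                                # B_1 = 0, so A_1 is orthogonal
            P[rk, rk] = semi_orthogonal(mk, mk, rng)
        else:
            mp = widths[k - 1]
            R = semi_orthogonal(mk, mp + mk, rng)  # R_k = [B_k  A_k]
            P[rk, offs[k - 1]:offs[k]] = -R[:, :mp]
            P[rk, rk] = R[:, mp:]
    X = np.diag(psi) @ P                          # X = Psi P
    H = X @ X.T                                   # block tridiagonal by construction
    # The strictly block-lower part of H equals -Lam W, see (16).
    lower = np.zeros((m, m))
    for k in range(1, len(widths)):
        rk, ck = slice(offs[k], offs[k + 1]), slice(offs[k - 1], offs[k])
        lower[rk, ck] = H[rk, ck]
    W = -np.linalg.solve(Lam, lower)
    Q, _ = np.linalg.qr(rng.standard_normal((m, n)))  # Q^T Q = I, hence Q Q^T <= I
    U = np.sqrt(gamma / 2.0) * np.linalg.solve(Lam, X @ Q)
    Y = np.sqrt(gamma / 2.0) * Q.T @ X.T
    return U, W, Y, Lam, H

rng = np.random.default_rng(0)
gamma = 3.0
U, W, Y, Lam, H = build_weights(n=2, widths=[3, 4, 2], gamma=gamma, rng=rng)
H_from_W = 2 * Lam - W.T @ Lam - Lam @ W
# Smallest eigenvalue of H - (2/gamma) Y^T Y; nonnegative up to rounding.
gap = np.linalg.eigvalsh(H_from_W - (2.0 / gamma) * Y.T @ Y).min()
```

In this sketch `H_from_W` reproduces `X @ X.T`, `Y` equals `U.T @ Lam`, and `gap` is nonnegative up to rounding, so the unconstrained draws satisfy the SDP constraint without any projection step.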

An equivalent model representation.

The model weights $U,Y,W$ defined in (5) can be rewritten as $U=\sqrt{2\gamma}\,\Psi^{-1}S$, $Y=\sqrt{\gamma/2}\,S^{\top}\Psi$ and $W=\Psi^{-1}V\Psi$ with

$$S=\begin{bmatrix}S_{1}\\ S_{2}\\ \vdots\\ S_{L}\end{bmatrix}=\begin{bmatrix}A_{1}Q_{1}\\ A_{2}Q_{2}-B_{2}Q_{1}\\ \vdots\\ A_{L}Q_{L}-B_{L}Q_{L-1}\end{bmatrix},\quad V=\begin{bmatrix}0&\\ V_{2}&0\\ &\ddots&\ddots\\ &&V_{L}&0\end{bmatrix}=\begin{bmatrix}0&\\ 2B_{2}A_{1}^{\top}&0\\ &\ddots&\ddots\\ &&2B_{L}A_{L-1}^{\top}&0\end{bmatrix} \tag{20}$$

where $Q=\begin{bmatrix}Q_{1}^{\top}&\cdots&Q_{L}^{\top}\end{bmatrix}^{\top}$. Then, the network (2) can be written as

$$z=\sigma(\Psi^{-1}V\Psi z+\sqrt{2\gamma}\,\Psi^{-1}Sx+b),\quad y=\mu x+\sqrt{\gamma/2}\,S^{\top}\Psi z. \tag{21}$$

By introducing the new hidden state $\hat{z}=\Psi z$ and bias $\hat{b}=\Psi b$, we obtain an equivalent form:

$$\hat{z}=\hat{\sigma}\bigl(V\hat{z}+\sqrt{2\gamma}\,Sx+\hat{b}\bigr),\quad y=\mu x+\sqrt{\gamma/2}\,S^{\top}\hat{z}+b_{y}. \tag{22}$$

This representation is useful for computing the model inverse via monotone operator splitting, see Appendix B. We now give a lemma which will be used later for proving some propositions.
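The change of variables behind this equivalent form can be written out in one line by multiplying the hidden-state equation by $\Psi$. The derivation below assumes that $\hat{\sigma}$ denotes the rescaled activation $\hat{\sigma}(v)=\Psi\sigma(\Psi^{-1}v)$; this appendix does not define $\hat{\sigma}$ explicitly, so that reading is an assumption.

```latex
\begin{aligned}
\hat{z} = \Psi z
 &= \Psi\,\sigma\bigl(\Psi^{-1}V\Psi z+\sqrt{2\gamma}\,\Psi^{-1}Sx+b\bigr) \\
 &= \Psi\,\sigma\bigl(\Psi^{-1}(V\hat{z}+\sqrt{2\gamma}\,Sx+\hat{b})\bigr)
  = \hat{\sigma}\bigl(V\hat{z}+\sqrt{2\gamma}\,Sx+\hat{b}\bigr),
 \qquad \hat{\sigma}(v) := \Psi\,\sigma(\Psi^{-1}v).
\end{aligned}
```

Under this reading, a positively homogeneous activation such as ReLU gives $\hat{\sigma}=\sigma$, since $\Psi$ is diagonal with positive entries.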

Lemma A.1.

For the matrices $V,S$ defined in (20) we have

$$2I-V-V^{\top}\succeq 0,\quad 2I-SS^{\top}\succeq 0. \tag{23}$$
Proof.

First, we have

$$2I-(V+V^{\top})=2\begin{bmatrix}I&-A_{1}B_{2}^{\top}\\ -B_{2}A_{1}^{\top}&I&-A_{2}B_{3}^{\top}\\ &-B_{3}A_{2}^{\top}&\ddots&\ddots\\ &&\ddots&\ddots\end{bmatrix}\succeq 0$$

where the inequality is obtained by sequentially applying the fact $A_{k}A_{k}^{\top}+B_{k}B_{k}^{\top}=I$ and the Schur complement to the top diagonal block. For the inequality on $S$, we have

$$\begin{aligned}2I-SS^{\top}&=2I-PQQ^{\top}P^{\top}\succeq 2I-PP^{\top}\\ &=2I-\begin{bmatrix}A_{1}\\ -B_{2}&A_{2}\\ &\ddots&\ddots\\ &&-B_{L}&A_{L}\end{bmatrix}\begin{bmatrix}A_{1}\\ -B_{2}&A_{2}\\ &\ddots&\ddots\\ &&-B_{L}&A_{L}\end{bmatrix}^{\top}=\begin{bmatrix}I&A_{1}B_{2}^{\top}\\ B_{2}A_{1}^{\top}&I&A_{2}B_{3}^{\top}\\ &B_{3}A_{2}^{\top}&\ddots&\ddots\\ &&\ddots&\ddots\end{bmatrix}\succeq 0.\end{aligned}$$

Similarly, the last inequality can be established by sequentially applying the Schur complement to the top diagonal block. ∎
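As a side remark, and not the authors' argument, both inequalities in (23) can also be read off from a direct factorization that uses only $A_{k}A_{k}^{\top}+B_{k}B_{k}^{\top}=I$ and $B_{1}=0$. Here $\tilde{P}$ is an auxiliary matrix introduced only for this remark: it equals $P$ with every block $-B_{k}$ replaced by $+B_{k}$.

```latex
\begin{aligned}
PP^{\top} &= I-\tfrac12\,(V+V^{\top})
 &&\Longrightarrow\quad 2I-V-V^{\top}=2PP^{\top}\succeq 0,\\
\tilde{P}\tilde{P}^{\top} &= I+\tfrac12\,(V+V^{\top}) = 2I-PP^{\top}
 &&\Longrightarrow\quad 2I-SS^{\top}\succeq 2I-PP^{\top}=\tilde{P}\tilde{P}^{\top}\succeq 0.
\end{aligned}
```

The first identity follows because the diagonal blocks of $PP^{\top}$ are $A_{k}A_{k}^{\top}+B_{k}B_{k}^{\top}=I$ and its sub-diagonal blocks are $-B_{k}A_{k-1}^{\top}=-\tfrac12 V_{k}$; the second uses $S=PQ$ with $QQ^{\top}\preceq I$.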

Appendix B Monotone Operator Splitting for Computing Model Inverse

Inspired by (Winston & Kolter, 2020; Revay et al., 2020), we compute $x=\mathcal{F}^{-1}(y)$ via an operator splitting method. We first present some background on monotone operator theory based on the survey (Ryu & Boyd, 2016), and then reformulate the model inverse as a three-operator splitting problem.

B.1 Monotone operator

An operator is a set-valued or single-valued map defined by a subset $\mathcal{A}\subseteq\mathbb{R}^{n}\times\mathbb{R}^{n}$; we use the notation $\mathcal{A}(x)=\{y\mid(x,y)\in\mathcal{A}\}$. For example, the affine operator is defined by $\mathcal{L}=\{(x,Wx+b)\mid x\in\mathbb{R}^{n}\}$. Another important example is the subdifferential operator $\partial f=\{(x,\partial f(x))\}$ for a proper function $f:\mathbb{R}^{n}\rightarrow\mathbb{R}\cup\{\infty\}$ with $f(z)=\infty$ for $z\notin\mathbf{dom}\,f$, where $\partial f(x)=\{g\in\mathbb{R}^{n}\mid f(y)\geq f(x)+\langle y-x,g\rangle,\ \forall y\in\mathbb{R}^{n}\}$.
An operator $\mathcal{A}$ has a Lipschitz bound of $L$ if $\|u-v\|\leq L\|x-y\|$ for all $(x,u),(y,v)\in\mathcal{A}$. It is non-expansive if $L=1$ and contractive if $L<1$. $\mathcal{A}$ is strongly monotone with $m>0$ if

$$\langle u-v,\,x-y\rangle\geq m\|x-y\|^{2},\quad\forall(x,u),(y,v)\in\mathcal{A}. \tag{24}$$

If the above inequality holds for $m=0$, we call $\mathcal{A}$ a monotone operator. Similarly, $\mathcal{A}$ is said to be inverse monotone with $\rho$ if $\langle u-v,\,x-y\rangle\geq\rho\|u-v\|^{2}$ for all $(x,u),(y,v)\in\mathcal{A}$. An operator is called maximal monotone if no other monotone operator strictly contains it. The linear operator $\mathcal{L}$ is $m$-strongly monotone if $W+W^{\top}\succeq 2mI$, and $\rho$-inverse monotone if $W+W^{\top}\succeq 2\rho W^{\top}W$. A subdifferential $\partial f$ is maximal monotone if and only if $f$ is a convex closed proper (CCP) function. Here are some basic operations for operators:

  • the operator sum $\mathcal{A}+\mathcal{B}=\{(x,y+z)\mid(x,y)\in\mathcal{A},\ (x,z)\in\mathcal{B}\}$;

  • the composition $\mathcal{A}\mathcal{B}=\{(x,z)\mid\exists y\ \mathrm{s.t.}\ (x,y)\in\mathcal{B},\ (y,z)\in\mathcal{A}\}$, i.e., $\mathcal{B}$ is applied first;

  • the inverse operator $\mathcal{A}^{-1}=\{(y,x)\mid(x,y)\in\mathcal{A}\}$;

  • the resolvent operator $R_{\mathcal{A}}=(I+\alpha\mathcal{A})^{-1}$ with $\alpha>0$;

  • the Cayley operator $C_{\mathcal{A}}=2R_{\mathcal{A}}-I$.

Note that the resolvent and Cayley operators are non-expansive for any maximal monotone $\mathcal{A}$, and are contractive if $\mathcal{A}$ is strongly monotone. For a linear operator $\mathcal{L}$ we have $R_{\mathcal{L}}(x)=(I+\alpha W)^{-1}(x-\alpha b)$. For a subdifferential operator $\partial f$, its resolvent is $R_{\partial f}(x)=\mathbf{prox}_{f}^{\alpha}(x):=\operatorname*{arg\,min}_{z}\tfrac{1}{2}\|x-z\|^{2}+\alpha f(z)$, which is also called the proximal operator.
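The resolvent formula for the affine operator can be sanity-checked numerically. The sketch below is illustrative only: the matrix `W`, the vector `b`, the step `alpha`, and the function names `resolvent` and `cayley` are made up for this example. It verifies that $y=R_{\mathcal{L}}(x)$ satisfies $y+\alpha(Wy+b)=x$ and that both the resolvent and the Cayley operator shrink the distance between two random points when $W+W^{\top}\succ 0$.

```python
import numpy as np

rng = np.random.default_rng(1)
n, alpha = 4, 0.5
G = rng.standard_normal((n, n))
# W + W^T = 2 (G G^T + I) is positive definite, so x -> W x + b is strongly monotone.
W = G @ G.T + np.eye(n) + (G - G.T)
b = rng.standard_normal(n)

def resolvent(x):
    """R(x) = (I + alpha W)^{-1} (x - alpha b)."""
    return np.linalg.solve(np.eye(n) + alpha * W, x - alpha * b)

def cayley(x):
    """C(x) = 2 R(x) - x."""
    return 2.0 * resolvent(x) - x

x, y = rng.standard_normal(n), rng.standard_normal(n)
rx = resolvent(x)
# Defining relation of the resolvent: rx + alpha * (W rx + b) = x.
residual = np.linalg.norm(rx + alpha * (W @ rx + b) - x)
ratio_R = np.linalg.norm(resolvent(x) - resolvent(y)) / np.linalg.norm(x - y)
ratio_C = np.linalg.norm(cayley(x) - cayley(y)) / np.linalg.norm(x - y)
```

Both ratios come out below one here, which matches the non-expansiveness statement above for this particular strongly monotone, Lipschitz affine operator.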

Activation as proximal operator.

As shown in (Li et al., 2019; Revay et al., 2020), many popular slope-restricted scalar activation functions can also be treated as proximal operators. To be specific, if $\sigma:\mathbb{R}\rightarrow\mathbb{R}$ is slope-restricted in $[0,1]$, then there exists a convex proper function $f$ such that $\sigma(\cdot)=\mathbf{prox}_{f}^{1}(\cdot)$. For completeness, we list common activations and their associated convex proper functions in Table 2; the same list can be found in (Revay et al., 2020; Li et al., 2019).

Table 2: A list of common activation functions and their associated convex proper $f(z)$ whose proximal operator is $\sigma(x)$ (Revay et al., 2020). For $z\notin\mathbf{dom}\,f$, we have $f(z)=\infty$. In the case of Softplus activation, $\mathrm{Li}_{s}(z)$ is the polylogarithm function.

| Activation | $\sigma(x)$ | Convex $f(z)$ | $\mathbf{dom}\,f$ |
|---|---|---|---|
| ReLU | $\max(x,0)$ | $0$ | $[0,\infty)$ |
| LeakyReLU | $\max(x,0.01x)$ | $\frac{99}{2}\min(z,0)^{2}$ | $\mathbb{R}$ |
| Tanh | $\tanh(x)$ | $\frac{1}{2}\left[\ln(1-z^{2})+z\ln\left(\frac{1+z}{1-z}\right)-z^{2}\right]$ | $(-1,1)$ |
| Sigmoid | $1/(1+e^{-x})$ | $z\ln z+(1-z)\ln(1-z)-\frac{z^{2}}{2}$ | $(0,1)$ |
| Arctan | $\arctan(x)$ | $-\ln(\lvert\cos z\rvert)-\frac{z^{2}}{2}$ | $(-\pi/2,\pi/2)$ |
| Softplus | $\ln(1+e^{x})$ | $-\mathrm{Li}_{2}(e^{z})-i\pi z-z^{2}/2$ | $(0,\infty)$ |
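Two rows of Table 2 can be checked directly from the definition $\mathbf{prox}_{f}^{1}(x)=\operatorname*{arg\,min}_{z}\tfrac12(x-z)^{2}+f(z)$. The sketch below is a rough numerical check rather than an efficient implementation: it minimizes the objective by brute-force search over a fine grid (the helper `prox_on_grid`, the grids, and the test points are all choices made for this illustration) and compares the minimizer with the closed-form activation.

```python
import numpy as np

def prox_on_grid(f, x, grid):
    """Approximate argmin_z 0.5 * (x - z)**2 + f(z) by exhaustive grid search."""
    objective = 0.5 * (x - grid) ** 2 + f(grid)
    return grid[np.argmin(objective)]

def f_tanh(z):
    # Table 2 entry for tanh, defined on (-1, 1).
    return 0.5 * (np.log(1 - z**2) + z * np.log((1 + z) / (1 - z)) - z**2)

def f_leaky(z):
    # Table 2 entry for LeakyReLU with slope 0.01: (99 / 2) * min(z, 0)^2.
    return 49.5 * np.minimum(z, 0.0) ** 2

xs = np.array([-2.0, -0.3, 0.0, 0.7, 1.5])
grid_tanh = np.linspace(-1 + 1e-6, 1 - 1e-6, 400001)   # stays inside dom f
grid_leaky = np.linspace(-3.0, 3.0, 600001)

err_tanh = max(abs(prox_on_grid(f_tanh, x, grid_tanh) - np.tanh(x)) for x in xs)
err_leaky = max(abs(prox_on_grid(f_leaky, x, grid_leaky) - max(x, 0.01 * x)) for x in xs)
```

Both errors are on the order of the grid spacing, which is what one expects if the tabulated $f$ is indeed the function whose proximal operator is the activation.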

B.2 Operator splitting

Many optimization problems (e.g. convex optimization) can be formulated as one of finding a zero of an appropriate monotone operator $\mathcal{F}$, i.e., find $x\in\mathbb{R}^{n}$ such that $0\in\mathcal{F}(x)$. Note that $x$ is a solution if and only if it is a fixed point $x=\mathcal{T}(x)$ with $\mathcal{T}=I-\alpha\mathcal{F}$ for any nonzero $\alpha\in\mathbb{R}$. The corresponding fixed point iteration is $x^{k+1}=\mathcal{T}(x^{k})=x^{k}-\alpha\mathcal{F}(x^{k})$. If $\mathcal{F}$ is $m$-strongly monotone and $L$-Lipschitz, then this iteration converges by choosing $\alpha\in(0,2m/L^{2})$. The optimal convergence rate is $1-(m/L)^{2}$, given by $\alpha=m/L^{2}$.
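The fixed-point iteration can be illustrated on an affine operator $\mathcal{F}(x)=Wx+b$, whose zero is $-W^{-1}b$. The sketch below uses made-up data: `W` is built as a skew-symmetric matrix plus $2I$ so that $W+W^{\top}=4I$ and $m=2$, and $L$ is taken as the spectral norm of `W`. With $\alpha=m/L^{2}$, expanding $\|\mathcal{T}(x)-\mathcal{T}(y)\|^{2}$ gives the bound $(1-2\alpha m+\alpha^{2}L^{2})\|x-y\|^{2}=(1-(m/L)^{2})\|x-y\|^{2}$, i.e., the stated rate applies to the squared distance.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5
G = rng.standard_normal((n, n))
W = (G - G.T) + 2.0 * np.eye(n)   # W + W^T = 4 I, so F is 2-strongly monotone
b = rng.standard_normal(n)

m = 2.0
L = np.linalg.norm(W, 2)          # Lipschitz bound of the affine map
alpha = m / L**2                  # step size suggested in the text

x = np.zeros(n)
for _ in range(2000):
    x = x - alpha * (W @ x + b)   # x^{k+1} = x^k - alpha * F(x^k)

x_star = -np.linalg.solve(W, b)   # exact zero of F, for comparison
error = np.linalg.norm(x - x_star)
```

For this well-conditioned example the iterate matches the exact zero to near machine precision; ill-conditioned operators (small $m/L$) converge correspondingly more slowly.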

If $\mathcal{F}$ contains some non-smooth components, we then split $\mathcal{F}$ into two or three maximal monotone operators:

$$\text{two-operator splitting problem:}\quad 0\in\mathcal{A}(x)+\mathcal{B}(x) \tag{25}$$
$$\text{three-operator splitting problem:}\quad 0\in\mathcal{A}(x)+\mathcal{B}(x)+\mathcal{C}(x) \tag{26}$$

where $\mathcal{A}$, $\mathcal{B}$, and $\mathcal{C}$ are maximal monotone. The main benefit of such a splitting is that the resolvent or Cayley operator of each individual operator is easy to evaluate, which leads to more computationally efficient algorithms. For the two-operator splitting problem, some popular algorithms include

  • forward-backward splitting (FBS) $x=R_{\mathcal{B}}(I-\alpha\mathcal{A})(x)$

  • forward-backward-forward splitting (FBFS) $x=((I-\alpha\mathcal{A})R_{\mathcal{B}}(I-\alpha\mathcal{A})+\alpha\mathcal{A})(x)$

  • Peaceman-Rachford splitting (PRS) $z=C_{\mathcal{A}}C_{\mathcal{B}}(z),\; x=R_{\mathcal{B}}(z)$

  • Douglas-Rachford splitting (DRS) $z=(\tfrac{1}{2}I+\tfrac{1}{2}C_{\mathcal{A}}C_{\mathcal{B}})(z),\; x=R_{\mathcal{B}}(z)$

where the corresponding fixed-point iterations, the choices of the hyper-parameter $\alpha$, and convergence results can be found in (Ryu & Boyd, 2016). For the three-operator splitting problem, the Davis-Yin splitting (DYS) (Davis & Yin, 2017) can be expressed as $z=\mathcal{T}(z),\, x=R_{\mathcal{B}}(z)$, where $\mathcal{T}=C_{\mathcal{A}}(C_{\mathcal{B}}-\alpha\mathcal{C}R_{\mathcal{B}})-\alpha\mathcal{C}R_{\mathcal{B}}$.
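As a minimal sketch of how a two-operator splitting is used in practice, the snippet below runs forward-backward splitting on a scalar toy problem. Everything in it is an assumption made for illustration: the smooth part is $\mathcal{A}(x)=x-3$, the non-smooth part is $\mathcal{B}=\partial|x|$ (whose resolvent with step $\alpha$ is soft-thresholding), and the function names are hypothetical. The inclusion $0\in\mathcal{A}(x)+\mathcal{B}(x)$ then has the unique solution $x=2$.

```python
def soft_threshold(v, t):
    """Resolvent of t * d|x|: shrink v toward zero by t."""
    if v > t:
        return v - t
    if v < -t:
        return v + t
    return 0.0


def forward_backward(op_a, resolvent_b, x0, alpha, iters=200):
    """Iterate x <- R_B(x - alpha * A(x)), the FBS fixed-point map."""
    x = x0
    for _ in range(iters):
        x = resolvent_b(x - alpha * op_a(x), alpha)
    return x


def op_a(x):
    """Illustrative smooth operator A(x) = x - 3 (1-strongly monotone, 1-Lipschitz)."""
    return x - 3.0


x_fbs = forward_backward(op_a, soft_threshold, x0=0.0, alpha=0.5)
```

The forward step only touches the smooth operator and the backward step only needs the resolvent of the non-smooth one, which is the computational advantage that splitting is meant to provide.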

B.3 Operator splitting perspective for $\mathcal{F}^{-1}$

As shown in the proof of Proposition 4.2, by applying forward-backward splitting with parameter $\alpha=1$, we can compute the solution $z$ via the following iteration:

$$z^{k+1}=R_{\mathcal{B}}\bigl(z^{k}-\widehat{\mathcal{A}}(z^{k})\bigr)=\hat{\sigma}\left(\left(V-\frac{\gamma}{\mu}SS^{\top}\right)z^{k}+b_{z}\right).$$

It is worth pointing out that the above iteration may not converge for the choice $\alpha=1$. In practice, more stable and faster two-operator splitting algorithms (e.g., PRS or DRS) are often used; see (Winston & Kolter, 2020; Revay et al., 2020). In this work, the motivation for further decomposing the monotone operator $\widehat{\mathcal{A}}$ into two monotone operators $\mathcal{A}$ and $\mathcal{C}$ is that evaluating $R_{\mathcal{A}}$ amounts to solving a large-scale linear equation with a nice sparse structure, whereas $R_{\widehat{\mathcal{A}}}$ is dense due to the full weight matrix in $\mathcal{C}$.

Fixed-point iteration.

We now apply the DYS algorithm from (Davis & Yin, 2017) to $0\in\mathcal{A}(z)+\mathcal{B}(z)+\mathcal{C}(z)$, resulting in the following fixed-point iteration:

$$\begin{split}
z^{k+1/2}&=R_{\mathcal{B}}(u^{k})=\mathbf{prox}_{f}^{\alpha}(u^{k})\\
u^{k+1/2}&=2z^{k+1/2}-u^{k}\\
z^{k+1}&=R_{\mathcal{A}}(u^{k+1/2}-\alpha\mathcal{C}(z^{k+1/2}))\\
u^{k+1}&=u^{k}+z^{k+1}-z^{k+1/2}
\end{split} \tag{27}$$

where the third line is a large-scale sparse linear equation of the form

$$\begin{bmatrix}(1+\alpha)I\\ -\alpha V_{21}&(1+\alpha)I\\ &\ddots&\ddots\\ &&-\alpha V_{L,L-1}&(1+\alpha)I\end{bmatrix}\begin{bmatrix}z_{1}^{k+1}\\ z_{2}^{k+1}\\ \vdots\\ z_{L}^{k+1}\end{bmatrix}=u^{k+1/2}+\alpha\left(b_{z}-\frac{\gamma}{\mu}SS^{\top}z^{k+1/2}\right).$$

By introducing $v^{k+1/2}=b_{z}-\frac{\gamma}{\mu}SS^{\top}z^{k+1/2}$, we have

$$z_{0}^{k+1}=0,\quad z_{l}^{k+1}=\frac{\alpha}{1+\alpha}\left(V_{l,l-1}z_{l-1}^{k+1}+v_{l}^{k+1/2}\right)+\frac{1}{1+\alpha}u_{l}^{k+1/2},\quad l=1,\ldots,L. \tag{28}$$
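The layer-by-layer recursion (28) is a forward substitution for the block lower-bidiagonal system above. A minimal sketch follows, under stated assumptions: blocks are stored as plain nested lists, `V_sub[l]` holds the sub-diagonal block that maps layer `l` to layer `l + 1` (0-indexed), and the weights and vectors are arbitrary numbers chosen for illustration. The function names `matvec`, `resolvent_a`, and `residual` are hypothetical.

```python
def matvec(mat, vec):
    """Dense matrix-vector product on nested lists."""
    return [sum(mij * xj for mij, xj in zip(row, vec)) for row in mat]


def resolvent_a(V_sub, u, v, alpha):
    """Solve ((1 + alpha) I - alpha V) z = u + alpha v block by block, as in (28)."""
    z = []
    for l, (u_l, v_l) in enumerate(zip(u, v)):
        if l == 0:
            prev = [0.0] * len(u_l)              # z_0 = 0: the first layer has no predecessor
        else:
            prev = matvec(V_sub[l - 1], z[l - 1])
        z.append([alpha / (1.0 + alpha) * (p + vi) + ui / (1.0 + alpha)
                  for p, vi, ui in zip(prev, v_l, u_l)])
    return z


def residual(V_sub, z, u, v, alpha):
    """Largest absolute entry of ((1 + alpha) I - alpha V) z - (u + alpha v)."""
    worst = 0.0
    for l in range(len(z)):
        prev = matvec(V_sub[l - 1], z[l - 1]) if l > 0 else [0.0] * len(z[l])
        for zi, p, ui, vi in zip(z[l], prev, u[l], v[l]):
            worst = max(worst, abs((1.0 + alpha) * zi - alpha * p - (ui + alpha * vi)))
    return worst


# Illustrative data: three layers of widths 2, 2, 1.
V_sub = [[[0.5, -0.2], [0.1, 0.3]],   # couples layer 0 to layer 1
         [[1.0, -1.0]]]               # couples layer 1 to layer 2
u_half = [[1.0, 2.0], [0.0, -1.0], [0.5]]
v_half = [[0.1, 0.2], [0.3, 0.4], [-0.5]]
alpha = 0.8

z_next = resolvent_a(V_sub, u_half, v_half, alpha)
res = residual(V_sub, z_next, u_half, v_half, alpha)
```

Each layer costs one matrix-vector product with its sub-diagonal block, so no factorization of the full system is needed.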

Convergence range for the hyper-parameter $\alpha$.

From the previous paragraph, we know that (11) is equivalent to the FPI (27). From Theorem 1.1 of (Davis & Yin, 2017), we have that (27) converges for any $\alpha\in(0,2\beta)$, with $\beta$ the inverse-monotone bound of $\mathcal{C}$. From Lemma A.1 we have $2I\succeq S^{\top}S$ and

$$\frac{2\gamma}{\mu}SS^{\top}\succeq\frac{\gamma}{\mu}S(S^{\top}S)S^{\top}=\frac{\mu}{\gamma}\left(\frac{\gamma}{\mu}SS^{\top}\right)^{2}=2\beta\left(\frac{\gamma}{\mu}SS^{\top}\right)^{2},$$

i.e., $\mathcal{C}(z)$ is inverse monotone with $\beta=\mu/(2\gamma)$. Therefore, (11) converges for any $\alpha\in(0,\mu/\gamma)$. Since $\gamma=\nu-\mu$ and $\tau=\nu/\mu$, we then obtain the convergence range in terms of the model distortion $\tau$, i.e., $\alpha\in(0,1/(\tau-1))$. A larger $\alpha$ often implies a faster convergence rate; see Figure 9.
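The step-size range is simple arithmetic in the bounds $\mu$ and $\nu$. The helper below, whose name `dys_alpha_upper_bound` is hypothetical, is a small sketch that evaluates both forms of the bound; the values $\mu=0.1$ and $\nu=10$ are used only as an example input.

```python
def dys_alpha_upper_bound(mu, nu):
    """Return (upper bound on alpha, distortion tau) for a mu-monotone, nu-Lipschitz layer."""
    if not 0.0 < mu < nu:
        raise ValueError("expected 0 < mu < nu")
    gamma = nu - mu          # Lipschitz bound of the nonlinear part
    tau = nu / mu            # model distortion
    return mu / gamma, tau   # alpha must lie in (0, mu / gamma) = (0, 1 / (tau - 1))


alpha_max, tau = dys_alpha_upper_bound(0.1, 10.0)
```

For this input the distortion is about 100, so the admissible step sizes are confined to roughly $(0, 0.0101)$: the range shrinks as the distortion grows.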

Figure 9: Solver comparison for computing the inverse of random $\mu$-monotone and $\nu$-Lipschitz layers with different input and hidden unit dimensions. We observe that DYS (27) converges faster for larger $\alpha$. If $\alpha$ is close to the bound $\mu/\gamma$ with $\gamma=\nu-\mu$, DYS converges much faster than FSM (15) with hyper-parameter $\alpha=\mu/\nu^{2}$, the choice at which FSM achieves its best convergence rate (Ryu & Boyd, 2016).

Appendix C Proofs

C.1 Proof of Theorem 3.2

We consider the neural network $\mathcal{H}:x\rightarrow\tilde{y}$ defined by

$$v=Wz+Ux+b,\quad z=\sigma(v),\quad \tilde{y}=Yz+b_{y}. \tag{29}$$

Since $\mathcal{F}(x)=\mu x+\mathcal{H}(x)$, $\mathcal{F}$ is $\mu$-strongly monotone and $\nu$-Lipschitz if $\mathcal{H}$ is monotone and $\gamma$-Lipschitz with $\gamma=\nu-\mu$.

For any pair of solutions $s_{1}=(x_{1},v_{1},z_{1},\tilde{y}_{1})$ and $s_{2}=(x_{2},v_{2},z_{2},\tilde{y}_{2})$, their difference $\Delta s=s_{1}-s_{2}$ satisfies

$$\Delta v=W\Delta z+U\Delta x,\quad \Delta z=J_{\sigma}(v_{1},v_{2})\Delta v,\quad \Delta\tilde{y}=Y\Delta z \tag{30}$$

where $J_{\sigma}$ is a diagonal matrix with $[J_{\sigma}]_{ii}\in[0,1]$, since $\sigma$ is an elementwise activation with slope restricted to $[0,1]$. For any $\Lambda\in\mathbb{D}_{+}^{m}$ we have

$$\left\langle\Delta v-\Delta z,\Lambda\Delta z\right\rangle=\Delta v^{\top}(I-J_{\sigma})\Lambda J_{\sigma}\Delta v\geq 0,\quad\forall\Delta v\in\mathbb{R}^{m}. \tag{31}$$
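The incremental quadratic constraint (31) can be checked numerically for a concrete slope-restricted activation. The sketch below assumes ReLU as $\sigma$ and draws random pre-activation pairs and positive diagonal multipliers; the sampling ranges and the function names are illustrative choices, not part of the proof.

```python
import random


def relu(v):
    """ReLU, an elementwise activation with slope restricted to [0, 1]."""
    return max(v, 0.0)


def iqc_value(v1, v2, lam):
    """Evaluate <dv - dz, Lambda dz> for a diagonal Lambda stored as the list lam."""
    total = 0.0
    for a, b, lam_i in zip(v1, v2, lam):
        dv = a - b
        dz = relu(a) - relu(b)
        total += lam_i * (dv - dz) * dz
    return total


random.seed(0)
width = 5
worst = float("inf")
for _ in range(1000):
    v1 = [random.gauss(0.0, 1.0) for _ in range(width)]
    v2 = [random.gauss(0.0, 1.0) for _ in range(width)]
    lam = [random.uniform(0.1, 2.0) for _ in range(width)]   # positive diagonal entries
    worst = min(worst, iqc_value(v1, v2, lam))
```

Each summand is non-negative on its own because $\Delta z_i$ always lies between $0$ and $\Delta v_i$, which is why any positive diagonal $\Lambda$ can be used as a multiplier.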

Based on (30), (31) and Condition (3) we have

$$\begin{split}
\left\langle\Delta x,\Delta\tilde{y}\right\rangle-\left\langle\Delta v-\Delta z,\Lambda\Delta z\right\rangle
=&\left\langle\Delta x,Y\Delta z\right\rangle-\left\langle(W-I)\Delta z+U\Delta x,\Lambda\Delta z\right\rangle\\
=&\left\langle\Delta x,Y\Delta z\right\rangle-\left\langle\Delta x,U^{\top}\Lambda\Delta z\right\rangle+\left\langle(I-W)\Delta z,\Lambda\Delta z\right\rangle\\
=&\frac{1}{2}\Delta z^{\top}\left(\Lambda(I-W)+(I-W^{\top})\Lambda\right)\Delta z\geq\frac{1}{\gamma}\|Y\Delta z\|^{2}\geq 0,
\end{split}$$

which further implies $\left\langle\Delta x,\Delta y\right\rangle-\mu\|\Delta x\|^{2}\geq\left\langle\Delta v-\Delta z,\Lambda\Delta z\right\rangle\geq 0$. Thus, $\mathcal{H}$ is monotone. We can use a similar technique to derive the Lipschitz bound of $\mathcal{H}$. First, we have

$$\begin{split}
\gamma\|\Delta x\|^{2}-\frac{1}{\gamma}\|\Delta\tilde{y}\|^{2}-2\left\langle\Delta v-\Delta z,\Lambda\Delta z\right\rangle
=&\gamma\|\Delta x\|^{2}-\frac{1}{\gamma}\|\Delta\tilde{y}\|^{2}+2\left\langle(I-W)\Delta z,\Lambda\Delta z\right\rangle-2\left\langle U\Delta x,\Lambda\Delta z\right\rangle\\
=&\gamma\|\Delta x\|^{2}-2\left\langle\Delta x,\Delta\tilde{y}\right\rangle-\frac{1}{\gamma}\|\Delta\tilde{y}\|^{2}+\Delta z^{\top}(2\Lambda-\Lambda W-W^{\top}\Lambda)\Delta z\\
\geq&\gamma\|\Delta x\|^{2}-2\left\langle\Delta x,\Delta\tilde{y}\right\rangle-\frac{1}{\gamma}\|\Delta\tilde{y}\|^{2}+\frac{2}{\gamma}\|Y\Delta z\|^{2}=\left\|\sqrt{\gamma}\Delta x-\frac{1}{\sqrt{\gamma}}\Delta\tilde{y}\right\|^{2}.
\end{split}$$

Due to (31) we can further obtain $\gamma^{2}\|\Delta x\|^{2}\geq\|\Delta\tilde{y}\|^{2}$, i.e., $\mathcal{H}$ is $\gamma$-Lipschitz.

C.2 Proof of Proposition 3.5

Sufficient part: (5) $\Rightarrow$ (3). From (17) we have that $Y=U^{\top}\Lambda$. We check the inequality part of (3) as follows:

$$2\Lambda-W^{\top}\Lambda-\Lambda W=\Psi PP^{\top}\Psi=XX^{\top}\succeq XQQ^{\top}X^{\top}=\frac{2}{\gamma}Y^{\top}Y.$$

Necessary part: (3) $\Rightarrow$ (5). Since $H\succeq 0$ has a band structure, it can be decomposed as $H=XX^{\top}$, where $X$ has the following block lower triangular structure (Davis, 2006):

$$X=\begin{bmatrix}X_{11}\\ X_{21}&X_{22}\\ &\ddots&\ddots\\ &&X_{L,L-1}&X_{LL}\end{bmatrix}.$$

For this special case, a way to construct $X$ from $\Lambda,W$ and to further compute the free parameters $d,F_{k}^{a},F_{k}^{b},F^{q},F^{\star}$ can be found in (Wang & Manchester, 2023). Finally, we need to show that $XX^{\top}\succeq(2/\gamma)Y^{\top}Y$ is equivalent to $Y=\sqrt{\gamma/2}\,Q^{\top}X^{\top}$ for some $Q$ with $QQ^{\top}\preceq I$, which follows directly from Lemma 3 of (Rantzer, 1996).

C.3 Proof of Proposition 4.1

From Lemma A.1 we have

$$2I-\left(V-\frac{\gamma}{\mu}SS^{\top}\right)-\left(V-\frac{\gamma}{\mu}SS^{\top}\right)^{\top}=2I-V-V^{\top}+\frac{2\gamma}{\mu}SS^{\top}\succeq\frac{2\gamma}{\mu}SS^{\top}\succeq 0. \tag{32}$$

Then, the equilibrium network (10) is well-posed by Theorem 1 of (Revay et al., 2020).

C.4 Proof of Proposition 4.2

We first show that $0\in\mathcal{A}(z)+\mathcal{B}(z)+\mathcal{C}(z)$ is a monotone operator splitting problem. It is obvious that $\mathcal{B},\mathcal{C}$ are maximal monotone operators. From Lemma A.1 we have $(I-V)+(I-V)^{\top}\succeq 0$, i.e., $\mathcal{A}$ is also monotone. Then, we show that the above operator splitting problem shares the same set of equilibrium points with the model inverse (10). First, we rewrite it as a two-operator splitting problem $0\in\widehat{\mathcal{A}}(z)+\mathcal{B}(z)$, where $\widehat{\mathcal{A}}=\mathcal{A}+\mathcal{C}$. By applying forward-backward splitting with parameter $\alpha=1$, we can compute the solution $z$ via the following iteration:

$$\begin{split}
z^{k+1}=&R_{\mathcal{B}}\bigl(z^{k}-\widehat{\mathcal{A}}(z^{k})\bigr)=\mathbf{prox}_{f}^{1}\left(z^{k}-\left(I-V+\frac{\gamma}{\mu}SS^{\top}\right)z^{k}+b_{z}\right)\\
=&\hat{\sigma}\left(\left(V-\frac{\gamma}{\mu}SS^{\top}\right)z^{k}+b_{z}\right).
\end{split}$$

Thus, any solution $z^{\star}$ of the equilibrium network (10) is also an equilibrium point of the above iteration.

C.5 Proof of Proposition 5.1

First, we have $\nabla f(x)=G^{\top}(x)\mathcal{G}(x)$, where $G(x)=\nabla\mathcal{G}(x)$ satisfies $\|G(x)\|\geq\mu$. Then, the PL inequality holds for $f$ with $m=\mu^{2}$, i.e.,

$$\frac{1}{2}\|\nabla f(x)\|^{2}=\frac{1}{2}\mathcal{G}(x)^{\top}G(x)^{\top}G(x)\mathcal{G}(x)\geq\frac{\mu^{2}}{2}\|\mathcal{G}(x)\|^{2}=\mu^{2}(f(x)-f^{\star}). \tag{33}$$
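To make the PL inequality concrete, the sketch below checks (33) on a grid for a scalar stand-in. All choices are assumptions for illustration: the map $\mathcal{G}(x)=x+0.5\sin x$ has derivative in $[0.5,1.5]$, so it plays the role of a bi-Lipschitz network with $\mu=0.5$, and $f(x)=\tfrac{1}{2}\mathcal{G}(x)^{2}$ has minimum value $f^{\star}=0$ at $x=0$. Note that $f$ is not convex, yet the inequality should still hold.

```python
import math

MU = 0.5   # lower bound on the derivative of the illustrative map


def g_map(x):
    """Scalar stand-in for a bi-Lipschitz network: derivative lies in [0.5, 1.5]."""
    return x + 0.5 * math.sin(x)


def g_jac(x):
    """Derivative of g_map."""
    return 1.0 + 0.5 * math.cos(x)


def f(x):
    """Quadratic potential composed with the map: f(x) = 0.5 * G(x)^2."""
    return 0.5 * g_map(x) ** 2


def grad_f(x):
    """Chain rule: f'(x) = G'(x) * G(x)."""
    return g_jac(x) * g_map(x)


def pl_gap(x, f_star=0.0):
    """Left-hand side minus right-hand side of the PL inequality with m = MU^2."""
    return 0.5 * grad_f(x) ** 2 - MU ** 2 * (f(x) - f_star)


grid = [-10.0 + 0.01 * i for i in range(2001)]
min_gap = min(pl_gap(x) for x in grid)
```

The gap is zero only where $\mathcal{G}(x)=0$ or where the derivative touches its lower bound, which matches the role of $\mu$ in the proof.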

Appendix D Experiments

D.1 Training details

We choose ReLU as our default activation and use Adam (Kingma & Ba, 2015) with a one-cycle linear learning rate schedule (Coleman et al., 2017), except in the NGP case, which uses SGD with a piecewise constant schedule. For the NGP case, we use the cross-entropy loss, while the L2 loss is used for the rest of the examples. We found that enforcing $Q^{\top}Q = I$, which can be done by fixing $F^{p} = 0$, can improve model training. Datasets and model architectures are described as follows.

1D Step function.

The target function is a step function

$$f(x) = \begin{cases} 2, & x > 0 \\ -2, & x < 0 \end{cases}$$

which is monotone and $0$-Lipschitz everywhere except at the singular point $x = 0$. We try to fit this curve with $(0.1, 10)$-Lipschitz models. The optimal fit is a continuous piecewise-linear function with a slope of 10 near $x = 0$ and a slope of 0.1 near $x = \pm 2$. We take 1000 random samples from $[-2, 2]$ for training. Our model (BiLipNet) is a one-layer residual network $\mathcal{F}(x) = \mu x + \mathcal{H}(x)$, where $\mathcal{H}$ has 8 hidden layers of width 32, giving the model 15.8K parameters. We compare to i-ResNet (Chen et al., 2019) and i-DenseNet (Perugachi-Diaz et al., 2021), where the nonlinear block $\mathcal{H}$ has 2 and 4 hidden layers, respectively. For those two models, we test depths from 2 to 8, with the hidden width chosen so that they have a similar number of parameters. The empirical Lipschitz bound is computed via finite differences over the test data. As shown in Figure 1, our model achieves much tighter bounds than the other models.
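One way to compute empirical bounds by finite differences is sketched below. The text does not specify the exact procedure, so the adjacent-sample scheme, the function names, and the stand-in model `piecewise_fit` (with a hypothetical knee at $|x| = 0.2$) are all assumptions made for illustration.

```python
def empirical_lipschitz_bounds(model, xs):
    """Return (min_slope, max_slope) of |model(b) - model(a)| / (b - a)
    over adjacent pairs of the sorted, de-duplicated 1D samples."""
    pts = sorted(set(xs))
    slopes = [abs(model(b) - model(a)) / (b - a) for a, b in zip(pts, pts[1:])]
    return min(slopes), max(slopes)


def piecewise_fit(x, mu=0.1, nu=10.0, knee=0.2):
    # Hypothetical continuous piecewise-linear model:
    # slope nu for |x| <= knee and slope mu outside that interval.
    if abs(x) <= knee:
        return nu * x
    sign = 1.0 if x > 0 else -1.0
    return sign * (nu * knee + mu * (abs(x) - knee))


test_xs = [-2.0 + 0.004 * i for i in range(1001)]  # uniform grid on [-2, 2]
lower, upper = empirical_lipschitz_bounds(piecewise_fit, test_xs)
print(f"empirical inverse Lipschitz bound: {lower:.4f}, Lipschitz bound: {upper:.4f}")
```

For a trained network, `model` would be replaced by the network's forward pass. These empirical values only bound the true constants from the inside (the empirical maximum slope is a lower estimate of the Lipschitz constant), so they are best read as a tightness diagnostic rather than a certificate.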

Neural Gaussian process.

We take 1000 two-moon data points as training data and 1000 Gaussian samples with mean $(1.3, -1.8)$ and variance $(0.02, 0.01)$ as OOD data. For all models, we use a fixed input weight to map the 2D input into a 128D hidden space, then perform the hidden-space transformation using bi-Lipschitz models, and finally add a Gaussian process as the output layer. SNGP uses 3 residual layers $x + \mathcal{H}(x)$, where the Lipschitz bound of $\mathcal{H}$ is $c < 1$. BiLipNet has one monotone and Lipschitz layer with two orthogonal layers, i.e., $K = 1$ for (8). The nonlinear block $\mathcal{H}$ of our model has 6 hidden layers of width 32. Both models are chosen to have the same number of parameters, roughly 233K.

CIFAR-10/100 datasets.

We first adopt the SNGP model from (Liu et al., 2020) and make some modifications as follows.

  • SNGP contains three bi-Lipschitz components, each including four residual layers of the form $x + \mathcal{H}(x)$. It used the spectral norm bound $c = 6$ for the weights inside $\mathcal{H}$, which means that the bi-Lipschitz property may not hold. To provide a certified guarantee of bi-Lipschitzness we need $c \in (0, 1)$. We tried three values of $c$: 0.35, 0.65 and 0.95. Since the Lipschitz bounds are $\mu = (1 - c)^{4}$ and $\nu = (1 + c)^{4}$, a larger $c$ implies a more expressive SNGP model.

  • We ran SNGP with and without batch normalization for the bi-Lipschitz components. As pointed out in (Liu et al., 2023), batch normalization may re-scale a layer’s spectral norm in unexpected ways, so there is no theoretical guarantee of the bi-Lipschitz property when batch normalization is applied.

  • Training the original SNGP takes about 95% of the GPU memory of an Nvidia RTX3090. With the same number of parameters, our model needs more GPU memory, as it uses the approach from (Trockman & Kolter, 2021) to perform the Cayley transform of convolution operators, which involves FFT and inverse FFT. In order to train both models on a single GPU, we reduce the width of SNGP so that it has a similar number of parameters to our model ($\sim 14$M).
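To make the effect of $c$ concrete, the following sketch evaluates the bounds $\mu = (1 - c)^{4}$ and $\nu = (1 + c)^{4}$ of a four-layer SNGP component for the three tested values. The function name and the printed ratio $\nu/\mu$ are illustrative additions, not part of the original experiments.

```python
def sngp_component_bounds(c, num_layers=4):
    """Bi-Lipschitz bounds of num_layers stacked residual layers x + H(x),
    where each H is c-Lipschitz with 0 < c < 1."""
    if not 0.0 < c < 1.0:
        raise ValueError("a certified bi-Lipschitz bound requires 0 < c < 1")
    mu = (1.0 - c) ** num_layers  # lower (inverse Lipschitz) bound
    nu = (1.0 + c) ** num_layers  # upper (Lipschitz) bound
    return mu, nu


for c in (0.35, 0.65, 0.95):
    mu, nu = sngp_component_bounds(c)
    print(f"c={c:.2f}  mu={mu:.3e}  nu={nu:.3f}  nu/mu={nu / mu:.3e}")
```

The lower bound shrinks quickly as $c$ approaches 1, which is why a larger $c$ permits a much wider range of input-output behavior.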

Our model has a similar structure to SNGP except that we replace their bi-Lipschitz components with our proposed bi-Lipschitz networks. Note that there is no batch normalization inside our bi-Lipschitz networks. All models are trained for 200 epochs using the mini-batch stochastic gradient descent (SGD) method with batch size of 256. We adjust the learning rate based on a piecewise constant schedule.

2D Rosenbrock function.

The true function is a Rosenbrock function defined by

$$r(x, y) = \frac{1}{200}(x - 1)^{2} + \frac{1}{2}\bigl(y - x^{2}\bigr)^{2}.$$

Note that we apply a scaling factor of $1/200$ to the classic Rosenbrock function. The above function is non-convex but has one minimum at $(1, 1)$. We also consider the combination of the above Rosenbrock function with the following 2D Sine function:

$$s(x, y) = 0.25\bigl(\sin(8(x - 1) - \pi/2) + \sin(8(y - 1) - \pi/2) + 2\bigr).$$

In this case $r(x, y) + s(x, y)$ still has a unique global minimum at $(1, 1)$, but there are many local minima. We take 5K random training samples from the domain $[-2, 2] \times [-1, 3]$. The proposed BiLipNet contains two monotone and Lipschitz layers (i.e., $K = 2$ for (8)). The nonlinear block $\mathcal{H}$ has 4 hidden layers of width 128. The model size is roughly 16K. The ICNN model has 8 hidden layers of width 180. The MLP has hidden units of $[128, 256, 256, 512]$. We trained i-ResNet and i-DenseNet with different depths and widths such that the total number of parameters is comparable with BiLipNet.
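The two target functions are simple to write down, and a coarse grid search gives a sanity check on the location of the global minimum. This is a minimal sketch: the function names and the grid resolution are illustrative, and the search domain is assumed to be $[-2, 2] \times [-1, 3]$.

```python
import math


def rosenbrock(x, y):
    # Classic Rosenbrock function scaled by 1/200.
    return (x - 1.0) ** 2 / 200.0 + 0.5 * (y - x ** 2) ** 2


def sine_2d(x, y):
    # Non-negative 2D sine term that vanishes at (1, 1).
    return 0.25 * (
        math.sin(8.0 * (x - 1.0) - math.pi / 2.0)
        + math.sin(8.0 * (y - 1.0) - math.pi / 2.0)
        + 2.0
    )


def target(x, y):
    return rosenbrock(x, y) + sine_2d(x, y)


# Coarse grid with spacing 0.05 over the assumed domain [-2, 2] x [-1, 3].
grid = [(-2.0 + 0.05 * i, -1.0 + 0.05 * j) for i in range(81) for j in range(81)]
best = min(grid, key=lambda p: target(*p))
print("grid minimizer:", best, "value:", target(*best))
```

Both terms are non-negative and vanish together only at $(1, 1)$, so adding the sine term keeps the global minimizer in place while introducing extra local minima.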

Parametric Rosenbrock function.

We consider the following parametric Rosenbrock function

$$r(x, y; p) = \frac{1}{200}(x - a)^{2} + \frac{1}{2}\bigl(y - b x^{2}\bigr)^{2}, \quad p = (a, b) \in [-1, 1]^{2}.$$

We take 10K random training data points. The partially BiLipNet contains 3 orthogonal layers and 2 monotone and Lipschitz layers (the $\mathcal{H}$ block of each layer has 4 hidden layers of width 128). The bias term of each orthogonal layer is produced by an MLP with hidden units of $[64, 128, 2]$, while the bias for the hidden units inside the $\mathcal{H}$ block is generated by an MLP of $[64, 128, 256, 512]$. The model’s bi-Lipschitz bound is chosen to be $(0.04, 16)$. The resulting model size is 604K.

ND Rosenbrock function.

We also consider the $N$-dimensional (with $N = 20$) Rosenbrock function:

$$R(x) = \frac{1}{N - 1}\sum_{i=1}^{N-1} r(x_{i}, x_{i+1})$$

which is non-convex and has a unique global minimum at $(1, 1, \ldots, 1)$. It also has many local minima. We take 10K random samples over the domain $[-2, 2]^{20}$ and train with a batch size of 200. Note that the data size is very small compared to the dimension. We then use 500K samples for testing. BiLipNet has two monotone and Lipschitz layers (i.e., $K = 2$ for (8)), where each layer has a nonlinear block $\mathcal{H}$ with 8 hidden layers of width 256 (model size $\sim$2.1M). For i-ResNet/i-DenseNet, we try different depths from 2 to 10 and observe that a depth of 5 yields slightly better results. The width of the hidden layers is chosen so that the model has a similar number of parameters to BiLipNet.
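The $N$-dimensional target can be generated as in the sketch below, which averages the scaled 2D Rosenbrock function over consecutive coordinate pairs. The helper names, the random seed, and the reduced sample count (1000 rather than 10K) are choices made for this illustration only.

```python
import random


def rosenbrock_pair(x, y):
    # Scaled 2D Rosenbrock term r(x, y).
    return (x - 1.0) ** 2 / 200.0 + 0.5 * (y - x ** 2) ** 2


def rosenbrock_nd(x):
    # R(x) = 1/(N-1) * sum_{i=1}^{N-1} r(x_i, x_{i+1}).
    n = len(x)
    return sum(rosenbrock_pair(x[i], x[i + 1]) for i in range(n - 1)) / (n - 1)


N = 20
rng = random.Random(0)  # fixed seed, chosen arbitrarily for reproducibility
samples = [[rng.uniform(-2.0, 2.0) for _ in range(N)] for _ in range(1000)]
values = [rosenbrock_nd(x) for x in samples]

print("R at the all-ones point:", rosenbrock_nd([1.0] * N))
print("smallest sampled value:", min(values))
```

Uniform samples in 20 dimensions rarely land near the minimizer, which illustrates why a data set of this size is small relative to the dimension.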

D.2 Extra results

Some extra results for the bi-Lipschitz models on two-moon and CIFAR-10/100 datasets are shown in Figure 10 and Table 3, respectively. Figure 11 depicts the additional results on surrogate loss learning.

| Dataset | Method | $c$ | Accuracy (↑), Clean | Accuracy (↑), Corrupted | ECE (↓), Clean | ECE (↓), Corrupted | NLL (↓), Clean | NLL (↓), Corrupted |
|---|---|---|---|---|---|---|---|---|
| CIFAR-10 | SNGP-BN | 0.95 | 94.7 ± 0.079 | 73.0 ± 0.461 | 0.017 ± 0.002 | 0.127 ± 0.010 | 0.166 ± 0.004 | 0.991 ± 0.054 |
| CIFAR-10 | SNGP-BN | 0.65 | 94.1 ± 0.159 | 72.3 ± 0.561 | 0.016 ± 0.000 | 0.116 ± 0.005 | 0.182 ± 0.005 | 0.985 ± 0.029 |
| CIFAR-10 | SNGP-BN | 0.35 | 92.3 ± 0.260 | 70.4 ± 0.800 | 0.008 ± 0.003 | 0.095 ± 0.007 | 0.231 ± 0.006 | 0.995 ± 0.031 |
| CIFAR-10 | BiLipNet | 0.95 | 86.2 ± 0.250 | 70.8 ± 0.469 | 0.020 ± 0.003 | 0.052 ± 0.005 | 0.423 ± 0.006 | 0.895 ± 0.020 |
| CIFAR-10 | BiLipNet | 0.65 | 86.7 ± 0.129 | 72.8 ± 0.592 | 0.015 ± 0.005 | 0.047 ± 0.009 | 0.400 ± 0.006 | 0.830 ± 0.024 |
| CIFAR-10 | BiLipNet | 0.35 | 84.5 ± 0.184 | 72.6 ± 0.216 | 0.010 ± 0.002 | 0.052 ± 0.004 | 0.457 ± 0.002 | 0.827 ± 0.008 |
| CIFAR-100 | SNGP-BN | 0.95 | 72.3 ± 0.513 | 44.8 ± 0.470 | 0.071 ± 0.006 | 0.091 ± 0.006 | 1.042 ± 0.018 | 2.476 ± 0.025 |
| CIFAR-100 | SNGP-BN | 0.65 | 67.8 ± 1.006 | 41.5 ± 0.916 | 0.117 ± 0.007 | 0.092 ± 0.002 | 1.231 ± 0.035 | 2.573 ± 0.036 |
| CIFAR-100 | SNGP-BN | 0.35 | 61.9 ± 0.741 | 37.0 ± 0.660 | 0.158 ± 0.006 | 0.098 ± 0.006 | 1.510 ± 0.029 | 2.760 ± 0.043 |
| CIFAR-100 | BiLipNet | 0.95 | 51.0 ± 0.480 | 35.8 ± 0.397 | 0.230 ± 0.006 | 0.137 ± 0.007 | 2.064 ± 0.024 | 2.718 ± 0.014 |
| CIFAR-100 | BiLipNet | 0.65 | 55.2 ± 0.426 | 39.2 ± 0.495 | 0.225 ± 0.004 | 0.137 ± 0.005 | 1.887 ± 0.021 | 2.576 ± 0.022 |
| CIFAR-100 | BiLipNet | 0.35 | 54.4 ± 0.438 | 41.1 ± 0.200 | 0.194 ± 0.008 | 0.126 ± 0.009 | 1.876 ± 0.031 | 2.447 ± 0.016 |
Table 3: Results for SNGP-BN (SNGP with batch normalization) and BiLipNet (without batch normalization) on CIFAR-10/100, averaged over 5 seeds. As pointed out by Liu et al. (2023), batch normalization may rescale a layer’s spectral norm in unexpected ways, so there is no theoretical guarantee of the bi-Lipschitz property for SNGP-BN. This may offer it extra expressive power, leading to improved clean and corrupted accuracy for large-distortion models (i.e., $c = 0.95$). For models with low distortion (i.e., $c = 0.35$), BiLipNet has better accuracy on the corrupted dataset.
Figure 10: Uncertainty quantification via neural Gaussian process with different bi-Lipschitz bound specifications.
Figure 11: Additional results for learning a surrogate loss for the Rosenbrock and Rosenbrock + Sine functions. The first row contains the true functions, while the remaining rows show the learned functions and errors for various surrogate loss models. Our model (BiLipNet) has the flexibility to capture the non-convex sub-level sets, but can also fit smoothed representations by reducing the distortion parameter.