Global-Local Regularization Via Distributional Robustness

July 25, 2023

1. Introduction

As the Wasserstein (WS) distance is a powerful and convenient tool for measuring closeness between distributions, Wasserstein Distributional Robustness (WDR) has been one of the most widely used variants of Distributional Robustness (DR). Here we consider a generic Polish space S endowed with a distribution \mathbb{P}. Let r:S\to\mathbb{R} be a real-valued (risk) function and c:S\times S\to \mathbb{R}_{+} be a cost function. The distributional robustness setting aims to find the distribution \tilde{\mathbb{P}} in the vicinity of \mathbb{P} that maximizes the expected risk [1, 2]:

    \[\sup_{\tilde{\mathbb{P}}:\mathcal{W}_{c}\left(\mathbb{P},\tilde{\mathbb{P}}\right)<\epsilon}\mathbb{E}_{\tilde{Z}\sim\tilde{\mathbb{P}}}\left[r\left(\tilde{Z}\right)\right],\]

where \epsilon>0 and \mathcal{W}_{c}\left(\mathbb{P},\tilde{\mathbb{P}}\right):=\inf_{\gamma\in\Gamma\left(\mathbb{P},\tilde{\mathbb{P}}\right)}\int c\,d\gamma denotes an optimal transport (OT), or WS, distance, with \Gamma\left(\mathbb{P},\tilde{\mathbb{P}}\right) the set of couplings whose marginals are \mathbb{P} and \tilde{\mathbb{P}}.
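As a small illustration of the definition of \mathcal{W}_{c}, the sketch below computes the OT cost between two uniform empirical distributions with the same number of points. It relies on the standard fact that, in this equal-size uniform case, an optimal coupling can be taken to be a permutation, so it simply enumerates permutations; this is only practical for a handful of points (an assignment solver such as `scipy.optimize.linear_sum_assignment` would be the usual choice for larger sets). The function name `empirical_wasserstein` and the toy data are our own.

```python
import itertools

import numpy as np


def empirical_wasserstein(x, y, cost):
    """OT cost W_c between uniform empirical measures on x and y (same size n).

    For equal-size uniform empirical measures an optimal coupling can be
    chosen as a permutation matrix (Birkhoff-von Neumann), so we brute-force
    over permutations. Exponential in n: illustration only.
    """
    n = len(x)
    if len(y) != n:
        raise ValueError("x and y must contain the same number of points")
    # Pairwise cost matrix C[i, j] = c(x_i, y_j).
    C = np.array([[cost(a, b) for b in y] for a in x], dtype=float)
    best = min(
        sum(C[i, perm[i]] for i in range(n))
        for perm in itertools.permutations(range(n))
    )
    # Each matched pair carries probability mass 1/n.
    return best / n


x = np.array([0.0, 1.0, 2.0])
y = np.array([2.5, 0.5, 1.5])
w1 = empirical_wasserstein(x, y, lambda a, b: abs(a - b))
w2 = empirical_wasserstein(x, y, lambda a, b: (a - b) ** 2)
print(w1, w2)
```

Here the optimal coupling matches the sorted points (0 with 0.5, 1 with 1.5, 2 with 2.5), so the cost is 0.5 for c(a, b) = |a - b| and 0.25 for the squared cost.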

Direct optimization over the set of distributions \tilde{\mathbb{P}} is often computationally intractable except in limited cases; we therefore cast this problem into its dual form. Assuming that r\in L^{1}\left(\mathbb{P}\right) is upper semi-continuous and that the cost c is non-negative and continuous with c(Z,\tilde{Z})=0\text{ iff }Z=\tilde{Z}, [1, 2] showed that the dual form is:

    \[\inf_{\lambda\geq0}\left\{ \lambda\epsilon+\mathbb{E}_{Z\sim\mathbb{P}}[\sup_{\tilde{Z}}\left\{ r\left(\tilde{Z}\right)-\lambda c\left(\tilde{Z},Z\right)\right\} ]\right\}.\]
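To see how the dual form behaves, consider a toy instance of our own choosing (not from the paper): S=\mathbb{R}, r(z)=z and c(\tilde{z},z)=(\tilde{z}-z)^{2}. Writing \tilde{z}=z+d, the inner supremum is \sup_{d}\{z+d-\lambda d^{2}\}=z+\frac{1}{4\lambda}, so the dual objective is \lambda\epsilon+\mathbb{E}[Z]+\frac{1}{4\lambda}, minimized at \lambda^{*}=\frac{1}{2\sqrt{\epsilon}} with value \mathbb{E}[Z]+\sqrt{\epsilon}. The sketch below approximates both the inner supremum and the outer minimization over \lambda by grid search and compares the result with this closed form; the grids and sample size are arbitrary illustrative choices.

```python
import numpy as np

# Toy instance (our own choice): r(z) = z, c(z_tilde, z) = (z_tilde - z)^2.


def dual_objective(lam, eps, z, d_grid):
    """Grid approximation of lam*eps + E_Z[sup_{z_tilde} {r(z_tilde) - lam*c(z_tilde, Z)}]."""
    # For r(z) = z the maximizing offset d = z_tilde - z does not depend on z,
    # so a single max over the offset grid serves every sample.
    inner = np.max(d_grid - lam * d_grid**2)
    return lam * eps + np.mean(z) + inner


rng = np.random.default_rng(0)
z = rng.normal(size=1000)               # samples Z ~ P
eps = 0.25
d_grid = np.linspace(-1.0, 6.0, 3501)   # candidate offsets z_tilde - z
lam_grid = np.linspace(0.05, 5.0, 496)  # candidate multipliers lambda >= 0

values = np.array([dual_objective(lam, eps, z, d_grid) for lam in lam_grid])
dual_value = values.min()
lam_star = lam_grid[values.argmin()]
closed_form = np.mean(z) + np.sqrt(eps)  # attained at lambda* = 1 / (2 sqrt(eps)) = 1

print(f"grid dual: {dual_value:.4f}, closed form: {closed_form:.4f}, lambda*: {lam_star:.2f}")
```

Even in this one-dimensional case, evaluating the dual requires an inner maximization for every candidate \lambda followed by an outer search over \lambda.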

When applying DR to supervised learning, \tilde{Z}=\left(\tilde{X},\tilde{Y}\right) is a data/label pair drawn from \tilde{\mathbb{P}} and r is the loss function [1, 2]. Because r involves only \tilde{Z}=\left(\tilde{X},\tilde{Y}\right)\sim\tilde{\mathbb{P}}, the modeling capacity of the above formulation is restricted, for two reasons. Firstly, for each anchor Z, the most challenging sample \tilde{Z} is defined as the maximizer of \sup_{\tilde{Z}} \left\{r(\tilde{Z}) - \lambda c(\tilde{Z}, Z) \right\}, where r(\tilde{Z}) is inherited from the primal form. Hence, the formulation cannot express a risk function r that involves both Z and \tilde{Z} (e.g., the Kullback-Leibler divergence KL\left(p\left(\tilde{Z}\right)\Vert p(Z)\right) between the predictions for Z and \tilde{Z}, as in TRADES [3]). Secondly, it is also \emph{impossible} to inject a \emph{global regularization term} involving a batch of samples \tilde{Z} and Z.

Contribution. To enable DR to tackle a variety of real-world problems efficiently, in this work we propose a rich OT-based DR framework, named Global-Local Regularization Via Distributional Robustness (GLOT-DR). Specifically, by designing special joint distributions \mathbb{P} and \tilde{\mathbb{P}} together with some constraints, our framework is applicable to a wide variety of real-world applications, including domain generalization (DG), domain adaptation (DA), semi-supervised learning (SSL), and adversarial machine learning (AML).

Additionally, GLOT-DR makes it possible to equip the model with not only a \emph{local regularization term} that enforces local smoothness and robustness, but also a \emph{global regularization term} that imposes a global effect targeting a downstream task. Moreover, by designing a specific WS distance, we develop a closed-form solution for GLOT-DR without using the dual form.

Technically, our solution turns the inner maximization in the dual form into sampling a set of challenging particles from a local distribution, which we can handle efficiently with the Stein Variational Gradient Descent (SVGD) [4] approximate inference algorithm. Based on the general GLOT-DR framework, we establish the settings for DG, DA, SSL, and AML and conduct experiments comparing GLOT-DR to state-of-the-art baselines in these real-world applications. Overall, our contributions can be summarized as follows:

We enrich the general DR framework so that it applies to many real-world applications, by enforcing both local and global regularization terms. We note that the global regularization term is crucial for many downstream tasks.
We propose a closed-form solution for GLOT-DR that does not involve the dual form in [1, 2]. We note that the dual form is \emph{not computationally tractable} due to the minimization over \lambda.
We conduct comprehensive experiments comparing GLOT-DR to state-of-the-art baselines in DG, DA, SSL, and AML. The experimental results demonstrate the merits of our approach and empirically show that both the local and the global regularization terms improve on existing methods across these scenarios.

2. Proposed framework

We propose a regularization technique based on optimal transport distributional robustness that can be widely applied to many settings including i) \emph{semi-supervised learning}, ii) \emph{domain adaptation}, iii) \emph{domain generalization}, and iv) \emph{adversarial machine learning}. In what follows, we present the general setting along with the notations used throughout the paper and technical details of our framework.


Figure 1: Overview of our proposed GLOT-DR framework

Assume that we have \emph{multiple labeled source domains} with the \emph{data/label} distributions \left\{ \mathbb{P}_{k}^{S}\right\} _{k=1}^{K} and a \emph{single unlabeled target domain} with the \emph{data} distribution \mathbb{P}^{T}. For the k-th source domain, we draw a batch of B_{k}^{S} examples as \left(X_{ki}^{S},Y_{ki}^{S}\right)\sim\mathbb{P}_{k}^{S}, where i=1,\ldots,B_{k}^{S}. Meanwhile, for the target domain, we sample a batch of B^{T} examples as X_{i}^{T}\sim\mathbb{P}^{T},\,i=1,\ldots,B^{T}.

It is worth noting that for the DG setting, we set B^{T}=0 (i.e., no target data are used in training). Furthermore, we consider the multi-class classification problem with the label set \mathcal{Y}:=\{1,...,M\}.

Hence, the prediction of a classifier is a probability vector belonging to the (M-1)-dimensional \emph{label simplex}. Finally, let f_{\psi}=h_{\theta}\circ g_{\phi} be our deep net with parameters \psi=(\phi,\theta), wherein g_{\phi} is the feature extractor and h_{\theta} is the classifier on top of the feature representations.
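A minimal NumPy sketch of the composition f_{\psi}=h_{\theta}\circ g_{\phi}. The architecture is an assumption made only for illustration (a one-layer ReLU feature extractor and a linear classifier head); the softmax at the end is what guarantees that every prediction lies on the label simplex.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax; subtracting the max keeps exp() numerically stable."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


def g_phi(x, phi):
    """Feature extractor (hypothetical: one dense ReLU layer)."""
    W, b = phi
    return np.maximum(x @ W + b, 0.0)


def h_theta(features, theta):
    """Classifier head: linear map to M logits, then softmax onto the simplex."""
    V, c = theta
    return softmax(features @ V + c)


def f_psi(x, psi):
    """f_psi = h_theta o g_phi with psi = (phi, theta)."""
    phi, theta = psi
    return h_theta(g_phi(x, phi), theta)


rng = np.random.default_rng(0)
d_in, d_feat, M = 5, 8, 3
phi = (rng.normal(size=(d_in, d_feat)), np.zeros(d_feat))
theta = (rng.normal(size=(d_feat, M)), np.zeros(M))
probs = f_psi(rng.normal(size=(4, d_in)), (phi, theta))
print(probs.shape, probs.sum(axis=1))
```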

Constructing Challenging Samples. As explained below, our method involves the construction of a random variable Z with distribution \mathbb{P} and another random variable \tilde{Z} with distribution \mathbb{\tilde{P}}, containing anchor samples \left(X_{ki}^{S},Y_{ki}^{S}\right),X_{i}^{T} and their perturbed counterparts \left(\tilde{X}_{kij}^{S},\tilde{Y}_{kij}^{S}\right),\tilde{X}_{ij}^{T}.

The inclusion of both anchor samples and perturbed samples allows us to define a unifying cost function containing local regularization, global regularization, and classification loss.

Concretely, we start by constructing Z, which contains repeated anchor samples:

    \[Z :=\left[\left[\left[X_{kij}^{S},Y_{kij}^{S}\right]_{k=1}^{K}\right]_{i=1}^{B_{k}^{S}}\right]_{j=0}^{n^{S}},\left[\left[X_{ij}^{T}\right]_{i=1}^{B^{T}}\right]_{j=0}^{n^{T}}.\]

Here, each source sample is repeated n^{S}+1 times, (X_{kij}^{S},Y_{kij}^{S})=(X_{ki}^{S},Y_{ki}^{S}),\,\forall j, while each target sample is repeated n^{T}+1 times, X_{ij}^{T}=X_{i}^{T},\,\forall j. The corresponding distribution of this random variable is denoted as \mathbb{P}. In contrast to Z, we next define the random variable \tilde{Z}\sim\tilde{\mathbb{P}}, whose form is
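The repetition that defines Z can be expressed with array operations. The sketch below assumes, for simplicity, equal source batch sizes B_{k}^{S}=B^{S} for all k, so that the source part fits in a single (K, B^{S}, n^{S}+1, d) array; the function name `build_anchor_variable` and this array layout are our own illustrative choices.

```python
import numpy as np


def build_anchor_variable(x_src, y_src, x_tgt, n_src, n_tgt):
    """Repeat every anchor n + 1 times along a new axis j = 0..n.

    x_src: (K, B_S, d) source inputs, y_src: (K, B_S) integer labels,
    x_tgt: (B_T, d) target inputs (B_T may be 0, as in the DG setting).
    Returns arrays of shape (K, B_S, n_src+1, d), (K, B_S, n_src+1), (B_T, n_tgt+1, d).
    """
    zx_src = np.repeat(x_src[:, :, None, :], n_src + 1, axis=2)
    zy_src = np.repeat(y_src[:, :, None], n_src + 1, axis=2)
    zx_tgt = np.repeat(x_tgt[:, None, :], n_tgt + 1, axis=1)
    return zx_src, zy_src, zx_tgt


rng = np.random.default_rng(0)
K, B_S, B_T, d = 2, 3, 4, 5
x_src = rng.normal(size=(K, B_S, d))
y_src = rng.integers(0, 10, size=(K, B_S))
x_tgt = rng.normal(size=(B_T, d))
zx_src, zy_src, zx_tgt = build_anchor_variable(x_src, y_src, x_tgt, n_src=2, n_tgt=1)
print(zx_src.shape, zy_src.shape, zx_tgt.shape)  # (2, 3, 3, 5) (2, 3, 3) (4, 2, 5)
```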

    \[\tilde{Z}:=\left[\left[\left[\tilde{X}_{kij}^{S},\tilde{Y}_{kij}^{S}\right]_{k=1}^{K}\right]_{i=1}^{B_{k}^{S}}\right]_{j=0}^{n^{S}},\left[\left[\tilde{X}_{ij}^{T}\right]_{i=1}^{B^{T}}\right]_{j=0}^{n^{T}}.\]

Here we note that for \tilde{X}^S_{kij}, the index k specifies the k-th source domain, the index i specifies an example in the k-th source batch, and the index j specifies the j-th perturbed example of the source example X^S_{ki}. Similarly, for \tilde{X}^T_{ij}, the index i specifies an example in the target batch, while the index j specifies the j-th perturbed example of the target example X^T_i.

We would like \tilde{Z} to contain both: i) anchor examples, i.e., \left(\tilde{X}_{ki0}^{S},\tilde{Y}_{ki0}^{S}\right)=\left(X_{ki}^{S},Y_{ki}^{S}\right) and \tilde{X}_{i0}^{T}=X_{i}^{T}; and ii) n^{S} perturbed source samples \left\{ \left(\tilde{X}_{kij}^{S},\tilde{Y}_{kij}^{S}\right)\right\} _{j=1}^{n^{S}} of \left(X^S_{ki},Y^S_{ki}\right) and n^{T} perturbed target samples \left\{ \tilde{X}_{ij}^{T}\right\} _{j=1}^{n^{T}} of X^T_i. To impose this requirement, we only consider sampling \tilde{Z} from distributions \tilde{\mathbb{P}} inside the Wasserstein ball around \mathbb{P}, i.e., satisfying

    \[\mathcal{W}_{\rho}\left(\mathbb{P},\tilde{\mathbb{P}}\right):=\underset{\gamma\in\Gamma\left(\mathbb{P},\tilde{\mathbb{P}}\right)}{\inf}\underset{\left(Z,\tilde{Z}\right)\sim\gamma}{\mathbb{E}}\left[\rho\left(Z,\tilde{Z}\right)\right]^{\frac{1}{q}}\leq\epsilon\]

Here we slightly abuse notation by using Y\in\mathcal{Y} to also denote its corresponding one-hot vector. By definition, this cost metric almost surely: i) enforces all 0-th (i.e., j=0) samples in \tilde{Z} to be anchor samples, i.e., \tilde{X}_{ki0}^{S}=X_{ki0}^{S}=X_{ki}^{S}; ii) allows perturbations on the input data, i.e., \tilde{X}_{kij}^{S}\neq X_{ki}^{S} and \tilde{X}_{ij}^{T}\neq X_{i}^{T} for j\neq0; iii) forbids perturbations on labels, i.e., \tilde{Y}_{kij}^{S}=Y_{kij}^{S} for all j (see Figure 1 for an illustration). The reason is that if either (i) or (iii) is violated on a set of non-zero measure, then \mathcal{W}_{\rho}\left(\mathbb{P},\tilde{\mathbb{P}}\right) becomes infinite.
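Since the explicit form of \rho is not reproduced in this section, the following sketch only checks the three properties listed above for one source batch: (i) the j=0 slots reproduce the anchors, (iii) labels are unchanged, and, as a stand-in for the budget on (ii), that each perturbation stays within a per-sample L2 radius \epsilon, mirroring the ball B_{\epsilon}(\cdot) used when sampling particles in the training procedure. The per-sample L2 budget and the name `satisfies_constraints` are our assumptions, not the paper's definition of \rho.

```python
import numpy as np


def satisfies_constraints(x_anchor, y_anchor, x_tilde, y_tilde, eps):
    """Check a perturbed source batch against its anchors.

    x_anchor: (B, d), y_anchor: (B,), x_tilde: (B, n+1, d), y_tilde: (B, n+1).
    """
    # (i) the j = 0 slot must reproduce the anchor exactly.
    anchors_kept = np.array_equal(x_tilde[:, 0, :], x_anchor)
    # (iii) labels may not be perturbed: every slot j keeps the anchor label.
    labels_kept = bool(np.all(y_tilde == y_anchor[:, None]))
    # Hypothetical per-sample budget: ||x_tilde_j - x_anchor||_2 <= eps for all j.
    dist = np.linalg.norm(x_tilde - x_anchor[:, None, :], axis=-1)
    within_budget = bool(np.all(dist <= eps))
    return anchors_kept and labels_kept and within_budget


rng = np.random.default_rng(0)
B, n, d, eps = 4, 2, 3, 0.1
x = rng.normal(size=(B, d))
y = np.array([0, 1, 2, 1])
noise = rng.normal(size=(B, n, d))
noise *= 0.5 * eps / np.linalg.norm(noise, axis=-1, keepdims=True)  # each norm = eps / 2
x_tilde = np.concatenate([x[:, None, :], x[:, None, :] + noise], axis=1)
y_tilde = np.repeat(y[:, None], n + 1, axis=1)

print(satisfies_constraints(x, y, x_tilde, y_tilde, eps))  # True
y_bad = y_tilde.copy()
y_bad[0, 1] = 2                                            # a perturbed label
print(satisfies_constraints(x, y, x_tilde, y_bad, eps))    # False
```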

Learning Robust Classifier. With \tilde{Z} and \tilde{\mathbb{P}} defined, we wish to learn good representations and regularize the classifier f_{\psi} via the following DR problem:

    \[\min_{\theta,\phi}\max_{\tilde{\mathbb{P}}:\mathcal{W}_{\rho}\left(\mathbb{P},\tilde{\mathbb{P}}\right)\leq\epsilon}\mathbb{E}_{\tilde{Z}\sim\tilde{\mathbb{P}}}\left[r\left(\tilde{Z};\phi,\theta\right)\right].\]

The cost function r\left(\tilde{Z};\phi,\theta\right):=\alpha r^{l}\left(\tilde{Z};\phi,\theta\right)+\beta r^{g}\left(\tilde{Z};\phi,\theta\right)+\mathcal{L}\left(\tilde{Z};\phi,\theta\right), with \alpha,\beta>0, is the weighted sum of a \emph{local-regularization function} r^{l}\left(\tilde{Z};\phi,\theta\right), a \emph{global-regularization function} r^{g}\left(\tilde{Z};\phi,\theta\right), and the \emph{loss function} \mathcal{L}\left(\tilde{Z};\phi,\theta\right), whose explicit forms depend on the task (DA, SSL, DG, or AML).

Intuitively, the optimization iteratively searches for the worst-case \tilde{\mathbb{P}} w.r.t. the cost r\left(\cdot;\phi,\theta\right) and then updates the network f_{\psi} to minimize this worst-case cost.

Training Procedure of Our Approach. In what follows, we present how to solve the above optimization efficiently. We first sample \left(X_{ki}^{S},Y_{ki}^{S}\right)_{i=1}^{B_{k}^{S}}\sim\mathbb{P}_{k}^{S},\forall k\,\text{and}\,X_{1:B^{T}}^{T}\sim\mathbb{P}^{T}. For each source anchor \left(X_{ki}^{S},Y_{ki}^{S}\right), we sample

\left[\tilde{X}_{kij}^{S}\right]_{j=1}^{n^{S}}\sim{q_{ki}^{S}} in the ball B_{\epsilon}\left(X_{ki}^{S}\right) with the \emph{density function proportional} to \exp\left\{ \lambda[\alpha s(X_{ki}^{S},\bullet;\psi)+\ell(\bullet,Y_{ki}^{S};\psi)]\right\}. Furthermore, for each target anchor X_{i}^{T}, we sample \left[\tilde{X}_{ij}^{T}\right]_{j=1}^{n^{T}}\sim{q_{i}^{T}} in the ball B_{\epsilon}\left(X_{i}^{T}\right) with the \emph{density function proportional} to \exp\left\{ \lambda\alpha s\left(X_{i}^{T},\bullet;\psi\right)\right\}.
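To make the local sampling targets concrete, the sketch below evaluates the unnormalized log-density \lambda[\alpha s(X,\tilde{x})+\ell(\tilde{x},Y)] for source particles, and \lambda\alpha s(X,\tilde{x}) for target particles (by omitting the label), using the symmetric KL divergence for s and the cross-entropy for \ell, as specified in the next paragraph. The linear-softmax classifier `W`, the L2 geometry of the ball, and all function names are hypothetical choices for illustration. A sampler such as SVGD only needs the gradient of this quantity, which does not depend on the unknown normalizing constant.

```python
import numpy as np


def softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


def symmetric_kl(p, q, tiny=1e-12):
    """0.5 * KL(p || q) + 0.5 * KL(q || p), computed along the last axis."""
    p = np.clip(p, tiny, 1.0)
    q = np.clip(q, tiny, 1.0)
    return 0.5 * np.sum(p * np.log(p / q), axis=-1) + 0.5 * np.sum(q * np.log(q / p), axis=-1)


def local_log_density(x_tilde, x_anchor, W, lam, alpha, y_anchor=None):
    """Unnormalized log-density of the local distribution around one anchor.

    x_tilde: (n, d) particles, x_anchor: (d,), W: (d, M) weights of a hypothetical
    linear-softmax classifier. With y_anchor (source anchor) this is
    lam * (alpha * s + CE); without it (target anchor) it is lam * alpha * s.
    """
    p_anchor = softmax(x_anchor @ W)   # (M,)
    p_tilde = softmax(x_tilde @ W)     # (n, M)
    log_q = lam * alpha * symmetric_kl(p_anchor, p_tilde)
    if y_anchor is not None:
        cross_entropy = -np.log(np.clip(p_tilde[:, y_anchor], 1e-12, 1.0))
        log_q = log_q + lam * cross_entropy
    return log_q


def project_to_ball(x_tilde, x_anchor, eps):
    """Radially project particles onto the L2 ball B_eps(x_anchor) (assumed geometry)."""
    delta = x_tilde - x_anchor
    norm = np.linalg.norm(delta, axis=-1, keepdims=True)
    return x_anchor + delta * np.minimum(1.0, eps / np.maximum(norm, 1e-12))


rng = np.random.default_rng(0)
d, M, n = 4, 3, 5
W = rng.normal(size=(d, M))
x = rng.normal(size=d)
particles = project_to_ball(x + rng.normal(size=(n, d)), x, eps=0.5)
print(local_log_density(particles, x, W, lam=1.0, alpha=2.0, y_anchor=1))
```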

To sample the particles from their local distributions, we use Stein Variational Gradient Descent (SVGD) [4] with an RBF kernel of width \sigma. The obtained particles \tilde{X}_{kij}^{S} and \tilde{X}_{ij}^{T} are then used to minimize the objective function when updating \psi=(\phi,\theta). Specifically, we use the cross-entropy for the classification loss term \ell and the symmetric Kullback-Leibler (KL) divergence for the local regularization term, s\left(X,\tilde{X};\psi\right)=\frac{1}{2}KL\left(f_{\psi}\left(X\right)\Vert f_{\psi}\left(\tilde{X}\right)\right)+\frac{1}{2}KL\left(f_{\psi}\left(\tilde{X}\right)\Vert f_{\psi}\left(X\right)\right).
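A minimal NumPy sketch of the SVGD update [4] with an RBF kernel. We assume the kernel convention k(x,x')=\exp(-\Vert x-x'\Vert^{2}/(2\sigma^{2})); the text only names the width \sigma. Each particle moves along \phi(x_{i})=\frac{1}{n}\sum_{j}[k(x_{j},x_{i})\nabla\log p(x_{j})+\nabla_{x_{j}}k(x_{j},x_{i})], where the first term pulls particles toward high density and the second repels them from each other. The score \nabla\log p is passed in as a callable; for GLOT-DR it would be the gradient of the local log-density with respect to \tilde{x}, which in practice would come from automatic differentiation. To keep the example self-contained we run it on a Gaussian toy target, and the optional `project` callback (e.g., a projection onto B_{\epsilon}) is a hypothetical hook of our own.

```python
import numpy as np


def rbf_kernel_and_grad(x, sigma):
    """RBF kernel k(x_j, x_i) = exp(-||x_j - x_i||^2 / (2 sigma^2)) and its gradient in x_j."""
    diff = x[:, None, :] - x[None, :, :]            # diff[j, i] = x_j - x_i, shape (n, n, d)
    k = np.exp(-np.sum(diff**2, axis=-1) / (2.0 * sigma**2))
    grad_k = -diff / sigma**2 * k[..., None]        # d k(x_j, x_i) / d x_j
    return k, grad_k


def svgd_step(x, score_fn, step_size, sigma):
    """x_i += step * (1/n) * sum_j [k(x_j, x_i) score(x_j) + grad_{x_j} k(x_j, x_i)]."""
    n = x.shape[0]
    k, grad_k = rbf_kernel_and_grad(x, sigma)
    drift = k.T @ score_fn(x)        # kernel-weighted scores: attraction to high density
    repulsion = grad_k.sum(axis=0)   # pushes each particle away from its neighbours
    return x + step_size * (drift + repulsion) / n


def run_svgd(x0, score_fn, n_steps=1000, step_size=0.1, sigma=1.0, project=None):
    x = np.array(x0, dtype=float)
    for _ in range(n_steps):
        x = svgd_step(x, score_fn, step_size, sigma)
        if project is not None:      # hypothetical hook, e.g. projection onto B_eps(anchor)
            x = project(x)
    return x


# Toy check on a Gaussian target N(mu, I), whose score is -(x - mu).
mu = np.array([1.0, -2.0])
rng = np.random.default_rng(0)
particles = run_svgd(rng.normal(size=(50, 2)), lambda x: -(x - mu))
print(particles.mean(axis=0))
```

The fixed step size and kernel width are kept deliberately simple; the original SVGD paper also discusses adaptive choices such as a median-heuristic bandwidth.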

Finally, the global-regularization function of interest r^{g}\left(\left[X_{ki}^{S}\right]_{k,i},\left[X_{i}^{T}\right]_{i};\psi\right) is defined accordingly depending on the task.

3. Experiments

To demonstrate the effectiveness of our proposed method, we evaluate its performance on various experimental protocols, including DG, DA, SSL, and AML. Where possible, we use the same optimizer configurations and hyper-parameters as prior work and report the results from the original papers.

 

3.1 Experiments for domain generalization

Table 1: Single domain generalization accuracy (%) on the CIFAR-10-C and CIFAR-100-C datasets with different backbone architectures. Bold font highlights the best results.

Table 1 shows the average accuracy obtained when, in turn, the model is trained on one category and evaluated on the rest. In every setting, GLOT-DR outperforms the other methods by large margins. Specifically, our method exceeds the second-best method, ME-ADA, by 3.2\% on CIFAR-10-C and 3.4\% on CIFAR-100-C. The substantial accuracy gains across various backbone architectures demonstrate the broad applicability of GLOT-DR.

Furthermore, we examine multi-source DG on the PACS dataset, where the classifier needs to generalize from multiple source domains to an unseen target domain. Our method is applicable in this scenario since it is designed to learn domain-invariant features better and to leverage the diversity of the generated data. We compare GLOT-DR against DSN, L-CNN, MLDG, Fusion, MetaReg, Epi-FCR, AGG, HE, and PAR. Table 2 shows that GLOT-DR outperforms the baselines in three of the four cases and surpasses the second-best baseline by 0.9\% on average. The most noticeable improvement is on the Sketch domain (\thickapprox2.4\%), which is the most challenging because its images are colorless and differ greatly in style from those of Art Painting, Cartoon, and Photo (i.e., a larger domain shift).

Table 2: Multi-source domain generalization accuracy (%) on the PACS dataset. Each column title indicates the target domain used for evaluation, while the remaining domains are used for training.

3.2 Experiments for domain adaptation

In this section, we conduct experiments on Office-31, a commonly used dataset for real-world unsupervised DA comprising images from three domains: Amazon (A), Webcam (W), and DSLR (D). We compare GLOT-DR against the following baselines: ResNet-50, DAN, RTN, JAN, GTA, CDAN, DeepJDOT, and ETD. For a fair comparison, we follow the training setup of CDAN and compare with other works using this configuration. As can be seen from Table 3, GLOT-DR achieves the best overall performance among the baselines, with 87.8\% accuracy. Compared with ETD, another OT-based domain adaptation method, our method improves accuracy by 4.1\% on the A\textrightarrow W task, 2.1\% on W\textrightarrow A, and 1.6\% on average.

Table 3: Accuracy (%) on Office-31 (Saenko et al., 2010) of ResNet50 model (He et al., 2016) in unsupervised DA methods.

3.3 Experiments for semi-supervised learning

Sharing a similar objective with DA, namely using unlabeled samples to improve model performance, SSL methods can also benefit from our proposed technique. We present empirical results on the CIFAR-10 benchmark with the ConvLarge architecture, following the protocol of VAT, which serves as a strong baseline in this experiment. The results in Figure 3 (training with 1,000 and 4,000 labeled examples) show that, with only n^{S}=n^{T}=1 perturbed sample per anchor, LOT-DR slightly outperforms VAT by \sim0.5\%. With more perturbed samples per anchor, this gap increases to approximately 1\% when n^{S}=n^{T}=2 and 1.5\% when n^{S}=n^{T}=4. Similar to the previous DA experiment, adding the global regularization term increases accuracy by \sim1\% in this setup.

Figure 3: Accuracy (%) on CIFAR-10 of the ConvLarge model in SSL settings when using 1,000 and 4,000 labeled examples (i.e., 100 and 400 labeled samples per class). Best viewed in color.

3.4 Experiments for adversarial machine learning

Table 4 shows the evaluation against adversarial examples. We compare our method with PGD-AT [6] and TRADES [3], two well-known defense methods in AML, as well as SAT. For a fair comparison, we use the same adversarial training setting for all methods, following the setting carefully investigated in [5]. We also compare with the adversarial distributional training methods ADT-EXP and ADT-EXPAM, which explicitly model the adversarial distribution as a normal distribution. As Table 4 shows, GLOT-DR outperforms all these baselines in natural accuracy and is at least on par with them in robustness. Specifically, compared to PGD-AT, our method improves natural accuracy by 0.8\% and robust accuracy against the PGD200 and AA attacks by around 1\%. Compared to TRADES, our method achieves the same level of robustness while improving accuracy on benign examples by 2.5\%. Notably, our method outperforms ADT by around 7\% under the PGD200 attack.

Table 4: Adversarial robustness evaluation on CIFAR-10 of the ResNet18 model. PGD, AA and B&B denote the robust accuracy against the PGD attack (with 10/200 iterations) (Madry et al., 2018), Auto-Attack (Croce and Hein, 2020) and the B&B attack (Brendel et al., 2019), respectively, while NAT denotes the natural accuracy. Results marked ⋆ are taken from Pang et al. (2020), while results marked ⋄ are our reproduced results.

4. Conclusion

Although DR is a promising framework for improving the robustness and generalization capability of neural networks, its current formulation has some limitations that hinder its application to real-world problems. Firstly, its formulation is not sufficiently rich to express a global regularization effect targeting many applications. Secondly, its dual form is not readily trainable when incorporated into the training of deep learning models. In this work, we propose a rich OT-based DR framework, named Global-Local Regularization Via Distributional Robustness (GLOT-DR), which is sufficiently rich for many real-world applications, including DG, DA, SSL, and AML, and has a closed-form solution. Finally, we conduct comprehensive experiments comparing GLOT-DR with state-of-the-art baselines. Empirical results demonstrate the merits of GLOT-DR on standard benchmark datasets.

References

[1] Blanchet, J. and Murthy, K. (2019). Quantifying distributional model risk via optimal transport. Mathematics of Operations Research, 44(2):565–600.

[2] Sinha, A., Namkoong, H., and Duchi, J. (2018). Certifying some distributional robustness with principled adversarial training. In International Conference on Learning Representations.

[3] Zhang, H., Yu, Y., Jiao, J., Xing, E., El Ghaoui, L., and Jordan, M. (2019). Theoretically principled trade-off between robustness and accuracy. In Proceedings of ICML, pages 7472–7482. PMLR.

[4] Liu, Q. and Wang, D. (2016). Stein variational gradient descent: A general purpose Bayesian inference algorithm. In Lee, D., Sugiyama, M., Luxburg, U., Guyon, I., and Garnett, R., editors, Proceedings of NeurIPS, volume 29.

[5] Pang, T., Yang, X., Dong, Y., Su, H., and Zhu, J. (2020). Bag of tricks for adversarial training. In International Conference on Learning Representations.

[6] Madry, A., Makelov, A., Schmidt, L., Tsipras, D., and Vladu, A. (2018). Towards deep learning models resistant to adversarial attacks. In International Conference on Learning Representations.


Hoang Phan, Trung Le, Trung Phung, Tuan Anh Bui, Nhat Ho, Dinh Phung
