
FedMargin - Federated Learning via Attentive Margin of Semantic Feature Representations

By Umberto Michieli, Samsung R&D Institute United Kingdom
By Mete Ozay, Samsung R&D Institute United Kingdom

Introduction

Let’s play a simple game. Open the photo gallery on your phone and briefly scroll through your images: do you see some patterns and recognize the objects you like in the images? Now imagine the photo gallery of any of your friends: what would it look like? Pretty different, isn’t it? They may like dogs and you like cats, they may like mountains and you like beaches, they may like tennis and you like athletics, and so on. Now imagine people living in another country or of a very different age than you: what would their photo galleries look like?
Each person has different preferences and habits, and lives in a constrained geographical region.

This makes life hard for AI models (e.g., deep learning models) that have been developed to empower our device capabilities. Moreover, users of these AI services often opt out of sharing private photos with, for example, a central server entity, therefore limiting the availability of data that we can use to train AI models.

When training a deep learning model, we would like it to work well for every user, even for those who opt out of data sharing.

Figure 1.   Data observed at distributed IoT clients k∈K are non-i.i.d. and imbalanced. This represents a challenge for federated learning of vision models in IoT systems.

In this blog, we present a new method (FedMargin), published in the IEEE Internet of Things Journal, to cope with these challenges [1].

To address these challenges, we study Federated Learning (FL) in Internet of Things (IoT) systems, which enables distributed model training using a large corpus of decentralized training data dispersed among multiple IoT clients [2]. In this distributed setting, system and statistical heterogeneity, in the form of highly imbalanced, non-independent and identically distributed (non-i.i.d.) data stored on multiple devices, are likely to hinder model training. Existing methods aggregate models while disregarding the internal representations being learned, which play an essential role in solving the pursued task, especially in the case of deep learning models. To leverage feature representations in an FL framework, we introduce a method, called Federated Margin (FedMargin), which computes client deviations using margins over feature representations learned on distributed data, and applies them to drive federated optimization via an attention mechanism. Local and aggregated margins are jointly exploited, taking into account local representation shift and representation discrepancy with a global model.

In addition, we propose three methods to analyse statistical properties of feature representations learned in FL, in order to elucidate the relationship between accuracy, margins and feature discrepancy of FL models. In experimental analyses, FedMargin demonstrates state-of-the-art accuracy and convergence rate across image classification and semantic segmentation benchmarks by enabling maximum margin training of FL models. Moreover, FedMargin reduces uncertainty of predictions of FL models compared to the baseline. In this work, we also evaluate FL models on dense prediction tasks, such as semantic segmentation, proving the versatility of the proposed approach.

Motivation: Why are Semantic Representations Important?

Representation learning is a powerful technique to address complex computer vision tasks, such as object recognition and image segmentation [3]. In this paradigm, a model is trained to learn rich and explanatory feature representations of its input, and the learned representations are employed by task-specific predictors (e.g., classifiers or detectors).
In prototype representation learning, the focus is on obtaining a few exemplars of feature embeddings representative of the available data.

Prototypical representations have been successfully adopted in multiple fields, such as few-shot image classification [4] and semantic segmentation [5], object recognition [6], domain adaptation [7] and continual learning [8] tasks.
Differently from those works, we employ class feature prototypes to derive representation margins [9, 10] and to formulate a weight attention mechanism for FL global model aggregation.

Figure 2.  Illustrative feature extraction (coloured dots) and prototype identification (black silhouettes)

In short, learned representations are crucial to understand what models have encoded into weights, and we argue that when aggregating models from different clients, learned representations play a key role.

Federated Learning Setup

In an FL system, clients optimize a local model on the local dataset to learn feature representations useful to perform an end task (e.g., object recognition or semantic segmentation).

In centralized FL systems, a central server coordinates optimization of a set of parameters of an aggregated model by minimizing a global learning objective without sharing local datasets.
Since the server does not have access to local user data, the global optimization step is a weighted sum of the local optimization steps.

Federated Averaging (FedAvg) [2] is a benchmark federated optimization algorithm widely used to solve the minimization problem above.

In FedAvg, a subset of clients is selected at each federated round. Selected clients download the aggregated model from the central server, perform local optimization minimizing an empirical objective, and send the solution back to the server. The server averages the solutions of the clients with weights proportional to the sizes of their local datasets.

The procedure is iterated for multiple federated rounds to reach a final aggregate model.
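To make the aggregation step concrete, here is a minimal sketch of the FedAvg server-side update, assuming each client returns its updated parameters as a dictionary of NumPy arrays (function and variable names are illustrative, not taken from the paper):

```python
def fedavg_aggregate(client_params, client_sizes):
    """Weighted average of client parameter dicts (the FedAvg server step).

    client_params: list of dicts {layer_name: NumPy array}, one per client.
    client_sizes:  list of local dataset sizes, used as aggregation weights.
    """
    total = float(sum(client_sizes))
    weights = [n / total for n in client_sizes]
    return {
        name: sum(w * params[name] for w, params in zip(weights, client_params))
        for name in client_params[0]
    }
```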

Our Approach: Margin-guided Federated Learning

We refer the reader to our paper [1] for further details on the theoretical motivation behind our approach, which is reported here in a short form.

Feature representations have been successfully employed in various computer vision tasks [8, 4]. In this work, we employ margins of prototypes for federated optimization of vision models. Our margin guided federated optimizer (FedMargin) is motivated by the results obtained from the recent theoretical and experimental analyses of generalization capacity of latent class-conditional prototypes.

At each round and client, a local model (typically composed of encoder and decoder modules) is computed. Each input sample is encoded and then fed to the classifier to retrieve class-wise probability scores. Features corresponding to the same class are then averaged to construct local latent class-conditional prototypes (i.e., the centroids of the encoded samples of a given class at a given time).
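As an illustration, the following is a minimal sketch of how local class-conditional prototypes could be computed from the encoded samples, assuming the encoder outputs have been collected into a NumPy array (names and shapes are our assumptions):

```python
import numpy as np

def class_prototypes(features, labels):
    """Local class-conditional prototypes: per-class centroids of encoded features.

    features: (N, D) array of encoder outputs for the local samples.
    labels:   (N,) array of integer class labels.
    Returns {class_id: (D,) prototype vector}.
    """
    return {int(c): features[labels == c].mean(axis=0) for c in np.unique(labels)}
```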

To guide the optimization, we rely on a combination of two clues derived from displacement of prototypes:
1. Local Prototype Margin (LPM) measures deviation of on-client prototypes before and after local training.
2. Aggregate Prototype Margin (APM) measures deviation of aggregate prototypes from local prototypes.

As a measure of displacement, we embrace margin theory [9, 10], in which prototype margins (PMs) measure the distance between features and class decision boundaries. Here, instead, we aim to measure the change of semantic representations among clients across different rounds in FL.

Figure 3.  Intuition of margin of features. Large margin is a desirable property when designing deep learning models.

Therefore, we propose a novel semantic prototype margin (SPM).

In FL, we employ SPMs in two cases, LPM and APM, as follows:

The LPM is defined by the margin between the local class prototypes before and after local training, hence measuring their change due to the local training epochs.

The APM is defined to measure discrepancy between the local and the aggregate set of prototypes, which is defined as the weighted average of all the local prototypes received from the clients.

We remark that APM requires transmission of prototypes from clients to the server. However, this does not raise privacy issues, since prototypes represent only an averaged statistic, over all local data, of already compressed feature representations; nor does it introduce a large communication overhead, as the size of prototypes is negligible compared to the model size. While the local deviation measured by LPM gives a hint of how much a model adapts its inner representation for each class, the server-side deviation measured by APM tells how much a local model changes its inner representations with respect to the prototypical representations aggregated over previous rounds and clients.
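The precise margin formulation is given in the paper [1]; as a rough illustration only, the sketch below uses the average per-class Euclidean distance between two prototype sets as a simplified proxy for the LPM and APM deviations (a deliberate simplification, not the exact definition):

```python
import numpy as np

def prototype_deviation(protos_a, protos_b):
    """Average per-class Euclidean distance between two prototype sets.

    protos_a, protos_b: dicts {class_id: (D,) vector}; only shared classes count.
    """
    shared = set(protos_a) & set(protos_b)
    dists = [np.linalg.norm(protos_a[c] - protos_b[c]) for c in shared]
    return float(np.mean(dists)) if dists else 0.0

# LPM proxy: deviation of on-client prototypes before vs. after local training.
#   lpm_k = prototype_deviation(protos_before_k, protos_after_k)
# APM proxy: deviation of local prototypes from the aggregate prototypes.
#   apm_k = prototype_deviation(protos_after_k, aggregate_protos)
```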

Federated Attention using Prototypes. Client deviations and respective weight attention vectors are computed starting from LPM and APM.

Intuitively, each attention vector represents a measure of client drift [11]: as prototypes computed using a model on a local client deviate from reference prototypes in terms of margin (either locally or on the server), higher attention is applied to their weights, and vice-versa. In other words, we use prototype displacement as a proxy for measuring deviations of the client models during local training stages.

We remark that, according to our definition, if a client is not able to build reliable latent representations (low margin), then the contribution of its model to the computation of the global weights is reduced during aggregation.

Finally, federated attention vectors are used to aggregate local weights at each round.
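As an illustrative sketch of the overall idea (not the exact mechanism of [1]), the snippet below turns per-client deviations into normalized attention weights via a softmax, which is an assumption on our part, and uses them to aggregate the client models:

```python
import numpy as np

def attention_weights(deviations, temperature=1.0):
    """Map per-client deviations to normalized attention weights via a softmax.

    Larger deviation -> larger attention on that client's weights (see text).
    """
    d = np.asarray(deviations, dtype=np.float64) / temperature
    e = np.exp(d - d.max())  # numerically stable softmax
    return e / e.sum()

def attentive_aggregate(client_params, deviations):
    """Aggregate client parameter dicts using the margin-derived attention."""
    att = attention_weights(deviations)
    return {
        name: sum(a * params[name] for a, params in zip(att, client_params))
        for name in client_params[0]
    }
```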

Our proposed FL method is called Federated Margin (FedMargin), and, differently from existing methods, it drives the federated optimization towards margin maximization of the learned feature representations, which is shown to correlate with higher accuracy.

Experimental Results

We refer the reader to our paper [1] for full experimental details and results.

We evaluate FedMargin on four image classification benchmarks (one of which is synthetically generated) and two semantic segmentation benchmarks.

We compare FedMargin with several FL baseline approaches that are related to our work in terms of the addressed problem. We include FedAvg [2] as the starting baseline, and the fairness policy FairAvg [12], which is especially useful in the case of heterogeneous datasets.

All the results of the analysed methods are reported in the same experimental configurations and with tolerance of partial workload enabled, unless otherwise specified, to allow for fair comparisons.

Table 1 shows the testing accuracy and the training loss values obtained using the aggregate models at the final round on the four classification datasets employed in our work. Additionally, the table shows the number of rounds required to reach 80% of the accuracy achieved by the centralized training (#R@80%). The evaluation is performed across three configurations of partial workload (i.e., stragglers), δ∈{0%,50%,80%}, and the average results along with the standard deviation are shown.
We observe that FedMargin robustly outperforms the competitors on every dataset.

Finally, we claim that FedMargin also shows a better convergence rate than the competing approaches. We examine this claim in the last block of Table 1: our approach converges much faster (i.e., in fewer rounds) than the compared approaches to the value indicated by 80% of the final accuracy score achieved by the centralized training. Similar to before, we report in Figure 4 the relative efficiency score of the compared approaches against FedMargin. Therefore, FedMargin significantly decreases the number of communication rounds needed, reducing energy consumption in case a target accuracy is set. Our approach is faster than the second best approach overall (i.e., FedProx) by 18%. Furthermore, FedMargin is on average 17% faster to converge than the second fastest approach on each dataset (i.e., FedProx on Synthetic and FEMNIST, FeSEM on MNIST, and FedRep on CelebA).

Table 1.   Final mean and standard deviation of: testing accuracy (%), training loss, and number of rounds to reach 80% of the accuracy level achieved by the centralized training (#R@80%). Evaluation on multiple classification datasets and methods, computed over δ∈{0%,50%,80%}. Centralized accuracies are 78.5, 99.0, 99.4, and 92.6, and centralized losses are 0.33, 0.00, 0.00, and 0.15 for Synthetic Data, MNIST, FEMNIST, and CelebA, respectively.

Figure 4.   Relative rescaled accuracy and efficiency gaps of compared approaches against the proposed FedMargin method.

Quantitative Metrics on Federated Representations

To properly investigate the effect of the proposed method, we quantitatively analyse the effect of the margin-guided aggregation mechanism on the shaping of the feature space, defining two metrics. The results reported here refer to the case in which no stragglers are present (i.e., δ=0%), thus the same number of models is aggregated by each compared approach. This scenario allows us to appreciate the sole effect of margin-driven aggregation on the statistics of the learned representations.

In Figure 5, we show the change of the Aggregate Mean Margin (AMM, defined as the margin among the aggregate set of prototypes) for different optimizers and datasets during training in FL.
FedMargin achieves higher AMM compared to other optimizers. This is a direct consequence of a better shaping of latent representations with improved class-discrimination acting as regularizer for learning meaningful feature representations similar to centralized training. This is shown to correlate with accuracy.
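Since AMM is described here only as the margin among the aggregate prototypes, the sketch below uses the mean distance from each prototype to its nearest other-class prototype as a simplified proxy (an assumption on our part, not the exact definition from [1]):

```python
import numpy as np

def aggregate_mean_margin(prototypes):
    """Mean distance from each prototype to its nearest other-class prototype.

    prototypes: dict {class_id: (D,) vector} of aggregate prototypes (>= 2 classes).
    """
    classes = list(prototypes)
    margins = [
        min(np.linalg.norm(prototypes[c] - prototypes[o]) for o in classes if o != c)
        for c in classes
    ]
    return float(np.mean(margins))
```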

In Figure 6, we show the change of the Maximum Mean Discrepancy [19] (MMD), computed between features extracted by an FL algorithm and features generated by centralized training. Overall, we observe that the distributions of features learned using FedMargin are consistently more similar to those learned in centralized training than those learned using FedAvg. Therefore, FedMargin converges to models producing internal representations similar to the ideal case (i.e., the centralized approach). Last, we also note that FedProx can achieve some latent regularization thanks to its proximal term; however, it is robustly surpassed by our proposed FedMargin.
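MMD itself follows the kernel two-sample test of [19]; a minimal biased estimate with an RBF kernel (fixed bandwidth assumed here for simplicity) could look as follows:

```python
import numpy as np

def mmd_rbf(x, y, sigma=1.0):
    """Biased estimate of the squared MMD between feature sets x and y.

    x: (n, D) features from the FL model; y: (m, D) features from the
    centralized model; sigma: RBF kernel bandwidth (assumed fixed here).
    """
    def k(a, b):
        sq = np.sum(a**2, 1)[:, None] + np.sum(b**2, 1)[None, :] - 2.0 * a @ b.T
        return np.exp(-sq / (2.0 * sigma**2))

    return k(x, x).mean() + k(y, y).mean() - 2.0 * k(x, y).mean()
```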

Figure 5.   Per-round AMM on image classification datasets

Figure 6.   Per-round MMD on image classification datasets

Experiments on Semantic Segmentation Benchmarks

Differently from image classification, the segmentation task is more challenging as it involves dense predictions and highly class-imbalanced datasets. Altogether, these circumstances make aggregating local models even more difficult.

We start by analysing the effect of i.i.d. structure (i.i.d.-ness) of data on mIoU of federated segmentation models. For this purpose, we distribute two datasets among clients using the Dirichlet distribution with parameter α. Then, we compare the baseline FedAvg and our FedMargin.
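A common way to implement such a split is sketched below, under the assumption that each class is divided among clients according to independent Dirichlet(α) proportions (smaller α yields more skewed, non-i.i.d. partitions):

```python
import numpy as np

def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Split sample indices among clients with per-class Dirichlet(alpha) proportions.

    Smaller alpha -> more skewed (non-i.i.d.) class distributions per client.
    Returns a list of index lists, one per client.
    """
    rng = np.random.default_rng(seed)
    client_indices = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        idx = np.flatnonzero(labels == c)
        rng.shuffle(idx)
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for k, split in enumerate(np.split(idx, cuts)):
            client_indices[k].extend(split.tolist())
    return client_indices
```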

The results in Figure 7 show the mIoU curves with different learning settings, in order to show the relationship between convergence of models and i.i.d.-ness of data.

Note that, as α decreases, the non-i.i.d.-ness of the distributed data increases, i.e., data heterogeneity and client drift increase.

FedMargin improves mIoU and robustness compared to FedAvg on every setup, and especially on highly non-i.i.d. data, where class-conditional representations on certain remote clients could be unreliable due to the non-i.i.d. partitioning (where only a few samples of particular classes are observed on certain clients).

Figure 7.   Change of mIoU on segmentation data distributed using different α values. Evaluation is performed across δ∈{0%,50%,80%}.

To qualitatively analyse the effect of our approach, in Figure 8, we visualise segmentation and entropy maps of three sample images, comparing the final aggregate models of FedAvg and FedMargin on different data splitting configurations (i.e., for different values of α). In particular, we show the predicted segmentation map, the entropy map of the final softmax layer, and the entropy map of the intermediate features.

Looking at the overall picture, we can appreciate a general improvement when going from more non-i.i.d. to more i.i.d. data, as the complexity of the optimization decreases.

Figure 8.   Qualitative analyses of representations learned with FedAvg and FedMargin for three non-i.i.d. to i.i.d. configurations. We show output maps (first row), softmax-level entropy maps (second row), and feature-level entropy maps (third row).

Output segmentation maps improve when data are more i.i.d., better resembling the segmentation maps produced by centralised training. FedMargin produces significantly better segmentation maps for more non-i.i.d. data (α=0.01) compared to FedAvg: more correct class identification and better object shaping. The ability to resolve class ambiguities is the direct consequence of the better latent space organization and regularization that FedMargin achieves by maximizing the prototype margin.

Second, we observe that FedMargin provides less uncertainty on the chosen classification labels than FedAvg, as shown by the softmax-level entropy maps. We report the entropy map of the pixel-wise softmax probabilities of the final model (computed via the pixel-wise Shannon entropy [20]).

Low entropy (dark blue) indicates a peaked distribution which is the reflection of high confidence of the network on its prediction, and vice-versa.

Ideally, the entropy should be low for every pixel. However, as we can observe from centralized training, contours of objects and certain regions of the images have high entropy due to uncertainty on the precise edge localization of the objects or due to intrinsic ambiguity with other classes (all the considered animals have fur with similar patterns). Thus, we observe how FedMargin provides less uncertainty on the chosen classification labels, producing generally darker entropy maps compared to FedAvg, especially on non-i.i.d. data.

Last, we report the feature-level entropy maps, upsampled to match the input resolution, which measure how representative a feature is at each pixel location. To compute them, features are first normalized such that the sum over the channels at each low-resolution pixel location is 1 (i.e., so that they can be treated as probability vectors), and then we compute their entropy. Ideally, features corresponding to the desired class should be well activated so that the decoder can discriminate between them and assign the correct label: this is the case of centralized training, where features corresponding to (certain parts of) the object class are bright (i.e., high entropy denoting many activated patterns). We observe that FedMargin produces a feature-level entropy map which is more similar to centralized training than the map produced by FedAvg (particularly visible for low values of α).
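A minimal sketch of the two entropy maps discussed above, assuming (C, H, W) arrays of logits and features (the magnitude-based feature normalization is our assumption, not necessarily the exact normalization used in [1]):

```python
import numpy as np

def entropy_map(probs, eps=1e-12):
    """Pixel-wise Shannon entropy of a (C, H, W) probability volume."""
    return -np.sum(probs * np.log(probs + eps), axis=0)

def softmax_entropy_map(logits):
    """Entropy of the per-pixel softmax distribution over classes."""
    e = np.exp(logits - logits.max(axis=0, keepdims=True))
    return entropy_map(e / e.sum(axis=0, keepdims=True))

def feature_entropy_map(features, eps=1e-12):
    """Entropy of features normalized to sum to 1 over channels at each pixel."""
    f = np.abs(features)  # assumption: magnitude-based normalization of activations
    return entropy_map(f / (f.sum(axis=0, keepdims=True) + eps))
```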

Conclusions

We introduced FedMargin, a distributed machine learning paradigm for vision models that can handle IoT clients characterized by system and statistical heterogeneity.

Previous approaches aggregate model weights while disregarding internal representations. FedMargin, instead, computes client deviations based on the margin of class-conditional representations, and uses them to drive federated optimization by means of an attentive mechanism. We perform an extensive analysis of the proposed method, where we investigate statistical properties of feature representations learned in FL according to multiple metrics based on margins and feature discrepancy of FL models. Moreover, the experimental analyses across a suite of federated datasets on both image classification and semantic segmentation demonstrated the effectiveness of our framework. In particular, we established a new benchmark on the federated semantic segmentation task, outlining a new research direction.

References

[1] U. Michieli, M. Toldo and M. Ozay, "Federated Learning via Attentive Margin of Semantic Feature Representations," IEEE Internet of Things Journal, 2022.

[2] B. McMahan, E. Moore, D. Ramage, S. Hampson and B. A. y Arcas, "Communication-efficient learning of deep networks from decentralized data," in Artificial Intelligence and Statistics (AISTATS), 2017.

[3] Y. Bengio, A. Courville and P. Vincent, "Representation learning: A review and new perspectives," IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI), vol. 35, no. 8, pp. 1798-1828, 2013.

[4] J. Snell, K. Swersky and R. Zemel, "Prototypical networks for few-shot learning," in Neural Information Processing Systems (NeurIPS), 4080-4090, 2017.

[5] K. Wang, J. H. Liew, Y. Zou, D. Zhou and J. Feng, "Panet: Few-shot image semantic segmentation with prototype alignment," International Conference on Computer Vision (ICCV), 2019.

[6] Y. Shicheng, W. Ying, H. Lianghua and M. Zhou, "Sparse Common Feature Representation for Undersampled Face Recognition," IEEE Internet of Things Journal, vol. 8, no. 7, pp. 5607-5618, 2021.

[7] M. Toldo, U. Michieli and P. Zanuttigh, "Unsupervised Domain Adaptation in Semantic Segmentation via Orthogonal and Clustered Embeddings," in Winter Applications on Computer Vision (WACV), 2021.

[8] U. Michieli and P. Zanuttigh, "Continual Semantic Segmentation via Repulsion-Attraction of Sparse and Disentangled Latent Representations," in Computer Vision and Pattern Recognition (CVPR) conference, 2021.

[9] K. Crammer, R. Gilad-Bachrach, A. Navot and N. Tishby, "Margin analysis of the LVQ algorithm," in Neural Information Processing Systems (NeurIPS), 2002.

[10] D. Nova and P. A. Estevez, "A review of learning vector quantization classifiers," Neural Computing Applications, vol. 25, no. 3, pp. 511-524, 2014.

[11] S. P. Karimireddy, S. Kale, M. Mohri, S. Reddi, S. Stich and A. T. Suresh, "SCAFFOLD: Stochastic controlled averaging for federated learning," in International Conference on Machine Learning (ICML), 2020.

[12] U. Michieli and M. Ozay, "Are All Users Treated Fairly in Federated Learning Systems?," in Computer Vision and Pattern Recognition (CVPR), Workshops, 2021.

[13] M. G. Arivazhagan, V. Aggarwal, A. K. Singh and S. Choudhary, "Federated learning with personalization layers," in arXiv preprint arXiv:1912.00818, 2019.

[14] L. Collins, H. Hassani, A. Mokhtari and S. Shakkottai, "Exploiting shared representations for personalized federated learning," in International Conference on Machine Learning (ICML), 2021.

[15] M. Xie, G. Long, T. Shen, T. Zhou, X. Wang, J. Jiang and C. Zhang, "Multi-center federated learning," in arXiv preprint arXiv:2108.08647, 2021.

[16] S. Ji, S. Pan, G. Long, X. Li, J. Jiang and Z. Huang, "Learning private neural language modeling with attentive aggregation," in International Joint Conference on Neural Networks, 2019.

[17] T. Li, A. K. Sahu, M. Zaheer, M. Sanjabi, A. Talwalkar and V. Smithy, "Feddane: A federated newton-type method," in Conference on Signals, Systems, and Computers, 2019.

[18] T. Li, A. K. Sahu, M. Zaheer, M. Sanjabi, A. Talwalkar and V. Smith, "Federated optimization in heterogeneous networks," in Conference on Machine Learning and Systems (MLSys), 2020.

[19] A. Gretton, K. M. Borgwardt, M. J. Rasch, B. Scholkopf and A. Smola, "A kernel two-sample test," in The Journal of Machine Learning Research (JMLR), 2012.

[20] W. Wan, J. Chen, T. Li, Y. Huang, J. Tian, C. Yu and Y. Xue, "Information entropy based feature pooling for convolutional neural networks," in International Conference on Computer Vision (ICCV), 2019.