STTA: enhanced text classification via selective test-time augmentation

PeerJ Computer Science

Introduction

Pre-trained language models have demonstrated superior performance on various natural language processing (NLP) tasks (He, Gao & Chen, 2023; Sanh et al., 2019; Liu et al., 2019; Wang, Wang & Yang, 2022), and most research efforts have therefore focused on improving model performance during training. These efforts include using larger models, various model structures (He, Gao & Chen, 2023), adversarial training (Liu et al., 2020), and data augmentation (Fang et al., 2022; Ren et al., 2021). Although some of these methods bring improvements, they usually require additional training or cumbersome hyperparameter tuning, and some even trade a significant computational cost for negligible improvements, which is impractical. All of these methods try to improve the model from different perspectives during training; few studies have focused on improving the performance of the model during inference.

Typically, data augmentation is used during model training by adding transformed copies of each example to expand the dataset. While data augmentation does not require tedious hyperparameter tuning, it often requires additional model training. Recent work has shown, however, that data augmentation can also be used during inference to obtain greater robustness (Shanmugam et al., 2020; Cohen, Rosenfeld & Kolter, 2019a), improved accuracy (Matsunaga et al., 2017a; Lyzhov et al., 2020), or uncertainty estimates (Conde et al., 2023; Conde & Premebida, 2022; Wang et al., 2019); this practice is known as test-time augmentation (TTA). TTA is a general method for obtaining “smooth” model predictions by aggregating the predictions of several transformed versions of a given input. It can be applied to any model and makes no assumptions about the model architecture or training method. For example, the predictions for various rotated and scaled versions of a test image can be averaged, so that the final prediction is robust to any single adverse rotation or scaling.

TTA has been widely used in computer vision (CV) tasks with remarkable success. Previous TTA research mainly concentrates on designing better strategies for aggregating the predictions of augmented samples. For instance, Shanmugam et al. (2020) proposed aggregating the augmentations through a learnable neural network, while Lyzhov et al. (2020) introduced a greedy policy search method that learns the test-time augmentation policy based on predictive performance on the validation set. Nevertheless, these TTA methods cannot be directly applied to the text domain and require additional access to the source data. Despite its success in CV, the exploration of TTA's potential in NLP remains nascent. One intuitive obstacle to TTA research in NLP is the lack of a clear understanding of which augmentations should be applied to text inputs at test time. Text data augmentation methods often involve more drastic transformations than image augmentations, frequently causing model mispredictions and diminishing the effectiveness of previous TTA methods. In CV, there exists a set of “standard” TTA transformations, such as rotation, scaling, and translation, which are typically perceived as label-preserving and still convey crucial visual information about the depicted object or scene. In contrast, label-preserving text transformations are often task-specific, which makes it hard to keep labels intact. For instance, word deletion and word swapping may fail to preserve the label.

Learning which augmentations to apply and how to aggregate the augmented predictions usually requires access to the source data, which introduces additional assumptions and inference latency and often incurs unaffordable computational costs. For example, the Partial-LR method listed in Table 1 requires additional labeled source data to train its learnable network. However, due to privacy or legal restrictions (e.g., on identity information or patient data), the source data is often inaccessible. Furthermore, methods without an effective mechanism for identifying and selecting samples (e.g., Mean, Max, and Hard Vote in Table 1) can produce aggregated predictions that deviate substantially from the ground truth when unstable text augmentation policies generate very noisy augmented samples. This is often the key obstacle that limits the effective application of TTA in NLP. Despite the popularity of TTA in CV, there is little dedicated research on TTA in NLP, so a new TTA method is urgently needed to advance TTA in the NLP field.

Table 1:
The difference between our proposed STTA and related TTA settings.
Method | Source free | Online decision | Selective aggregation
Mean | Yes | Yes | No
Max | Yes | Yes | No
Smote | Yes | Yes | No
Hard Vote | Yes | Yes | No
Partial-LR | No | No | Yes
STTA | Yes | Yes | Yes
DOI: 10.7717/peerjcs.1757/table-1

In this work, we mainly focus on (1) understanding which transformed versions of an input change the TTA prediction and (2), based on these insights, reducing the risk of erroneous TTA aggregated predictions and further stabilizing TTA improvements in NLP classification. We first empirically analyze common and representative data augmentation methods and discuss their impact on TTA policy design. Based on this analysis, we propose an online selective TTA method, called “STTA” (Selective Test-Time Augmentation), which partitions the transformed samples according to reward and risk. We believe that different transformed versions of a sample contribute differently to the aggregated prediction, so different augmented samples should play different roles during testing. We divide the augmented samples into four roles based on similarity and confidence: gold, bonus, potential, and risk. Note that the TTA outcome depends primarily on two factors: the design of the augmentation policy and the effective aggregation of the predictions of augmented samples. This study focuses on the latter.

Overall, the research results show that the proposed method is lightweight and easy to implement. Selective TTA can provide significant improvements in the accuracy of text classification and is almost free in terms of computational overhead. Our contributions are as follows:

Main contributions

  • We analyzed and empirically verified why TTA is sensitive to some data augmentation methods and revealed why some data augmentation methods lead to erroneous predictions.

  • We proposed a simple yet effective online TTA method that selectively aggregates augmented predictions based on risk and reward criteria, thereby reducing the bias caused by abnormal transformed samples and enhancing the robustness of the model.

  • Our proposed method is “plug-and-play”: it can be applied to any existing model without hyperparameter tuning or model modification, and it can be seamlessly combined with other robustness methods.


Related Work

Test-time augmentation

TTA is a technique applied to a trained model during testing, where multiple augmented samples are generated for each original sample. The average prediction over the augmented samples is then used as the aggregated result to improve the final prediction of the model. Although data augmentation is typically applied during model training, it can also be used during prediction. TTA has been widely demonstrated to enhance model accuracy and robustness (Krizhevsky, Sutskever & Hinton, 2012; Matsunaga et al., 2017b; Cohen, Rosenfeld & Kolter, 2019b), address distribution shift issues (Zhang, Levine & Finn, 2022), and defend against adversarial attacks (Prakash et al., 2018; Gao et al., 2020). Researchers have proposed various TTA methods in different domains, including image segmentation (Moshkov et al., 2020), text grammar correction (Yang et al., 2022), text classification (Lu et al., 2022), audio-text retrieval (Kim et al., 2022), theoretical research (Kimura, 2021; Kim, Kim & Kim, 2020), and uncertainty estimation (Conde et al., 2023; Conde & Premebida, 2022).

Although TTA has been widely studied in many tasks (Shanmugam et al., 2020; Lyzhov et al., 2020; Guo et al., 2017; Kimura, 2021), existing methods typically assume that the source data is accessible and offer little insight into data augmentation policies, and this assumption is often impractical in reality. For example, both Shanmugam et al. (2020) and Lyzhov et al. (2020) use a labeled source dataset to learn the aggregation. On the other hand, the standard TTA method averages the predictions of all augmented samples, the Max method selects the maximum logit across all augmented samples, and Smote (Fernández et al., 2018) only interpolates between nearest samples to generate new ones. All three are prone to interference from noisy augmented samples, which can cause significant bias in the aggregated prediction. In this work, we propose a simple yet effective TTA method, called STTA, which combines confidence and similarity to select reliable augmented samples for aggregation.

Ensembling

Deep ensembles use multiple neural networks instead of a single one to compute predictions, which improves performance on a variety of machine learning problems. Typically, ensembling involves obtaining a set of trained neural networks, each based on a different algorithm or variant, and averaging their predictions for each test object. Approaches ranging from traditional methods such as Bagging (Breiman, 1996), Boosting (Schapire, 2003), and Stacking (Ting & Witten, 1997) to more recent homogeneous ensembles (Ganaie et al., 2022) and heterogeneous ensembles (van Rijn et al., 2018; Fang et al., 2021) have produced better-performing ensemble models.

Sub-ensemble selection

Ensembling has been demonstrated to enhance overall performance and robustness. However, training multiple models for ensembling requires additional computational overhead (Shen, He & Xue, 2019). Even though TTA uses a single model, it can reasonably be viewed as an ensemble of different models, because each sample produced by an augmentation sub-policy can be regarded as a new test sample. More specifically, whereas an ensemble uses multiple models to generate different predictions for the same sample, TTA uses a single model to generate different predictions for multiple augmented samples. Aggregating the predictions of the augmented samples therefore approximates an ensemble of multiple models. A related instance of this idea is the work of Fern & Lin (2008), who proposed selecting the best clustering results from multiple clustering models for aggregation. This mirrors the idea behind TTA: aggregating diverse augmented samples to improve overall prediction accuracy.

Problem Definition and Motivation

Let us suppose we have a C-class classification model $f_\theta$ with parameters $\theta$, trained on labeled source training data $\mathcal{D}_S = \{(x_i, y_i)\}_{i=1}^{N_S}$, where $x_i$ and $y_i$ represent the input and corresponding label, respectively. During inference at time step t, the model can only access the current data $x_t$, because the unlabeled target-domain data arrive sequentially as a stream. To predict $x_t$, a batch of N augmented samples $\{\tilde{x}_i = a_i(x_t)\}_{i=1}^{N}$ is generated with augmentation functions $a_i \in \mathcal{A}$ drawn from the uniform distribution $\mathcal{U}(\mathcal{A})$. More formally, the model $f_\theta$ should make an online decision $\tilde{p}_t$, instead of $p_t$, based on the current input $x_t$ by aggregating the predictions for the augmented samples $\tilde{x}_i$. The standard TTA method can be formulated as follows:

$$\tilde{p}_t = \frac{1}{N}\sum_{i=1}^{N} \sigma\big(f_\theta(\tilde{x}_i)\big),$$

where $f_\theta(\tilde{x}_i) \in \mathbb{R}^C$ denotes the model's logits, $\sigma: \mathbb{R}^C \to \Delta^C$ is the softmax function that maps them to a probability distribution, and

$$\Delta^C = \Big\{ p_t = (p_t^1, \dots, p_t^C) \in [0,1]^C : \sum_{j=1}^{C} p_t^j = 1 \Big\}$$

is the probability simplex.
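The standard TTA rule above can be sketched in a few lines of NumPy. The logits are made-up numbers standing in for $f_\theta(\tilde{x}_i)$; no real model is involved.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax sigma: R^C -> Delta^C, shifted by the row max for stability."""
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def standard_tta(aug_logits: np.ndarray) -> np.ndarray:
    """Standard TTA: average the softmax outputs of N augmented copies.

    aug_logits has shape (N, C); row i holds the logits f_theta(x~_i).
    Returns the aggregated probability vector p~_t of length C.
    """
    return softmax(aug_logits).mean(axis=0)


# Toy input: N = 3 augmented copies of one test sample, C = 2 classes.
aug_logits = np.array([[2.0, 0.0],
                       [1.0, 1.0],
                       [0.0, 3.0]])
p_tilde = standard_tta(aug_logits)
print(p_tilde, p_tilde.argmax())
```

Here the third copy's strong vote for class 1 outweighs the first copy's vote for class 0, so the averaged prediction is class 1 (about 0.48 vs. 0.52). Every copy counts equally, which is what makes plain averaging vulnerable to a single label-destroying augmentation.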

Motivation

In fact, TTA is still in its infancy in the NLP field, and its effectiveness remains an open question. Two main obstacles stand in the way: (1) It is still unclear which data augmentation methods preserve labels at test time, so simply averaging the predictions of all samples is prone to prediction bias. (2) Previous TTA methods learn the aggregation through neural networks. On the one hand, they assume that the source data is accessible and labeled, as shown in Table 1; on the other hand, they may fail when the text augmentation methods are inappropriate. These obstacles motivate us to propose a selective TTA method.

Why selectively aggregate augmented samples?

To gain a deeper understanding of the reasons underlying the ineffectiveness of TTA caused by text data augmentation, one intuitive approach is to visualize the importance levels of different words in sentences to demonstrate the variations caused by different augmentation strategies. Therefore, we first select the most representative text data augmentation methods, namely RWD, RPI, RWS, RWI, and RWSR, with detailed information as follows:

  • RPI (Random punctuation insertion) (Karimi, Rossi & Prati, 2021): Inserting punctuation marks randomly within a text sequence with a probability of 0.3.

  • RWSR (Random word synonym Replacement) (Wei & Zou, 2019): Randomly replacing m non-stop words in a sentence with synonyms with a probability of 0.1.

  • RWI (Random word insertion) (Wu et al., 2020): Inserting a synonym of a randomly selected non-stop word at a random position in the sentence with a probability of 0.1.

  • RWS (Random word swap) (Wei & Zou, 2019): Swapping the positions of two words in a sentence randomly with a probability of 0.1.

  • RWD (Random word deletion) (Bayer, Kaufhold & Reuter, 2022): Randomly deleting words from a sentence with a probability of 0.1.
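The five operations can be sketched in plain Python as below. This is a minimal illustration, not the implementations from the cited works: the stop-word list and synonym table (`STOP_WORDS`, `SYNONYMS`) are hypothetical stand-ins for a real lexicon such as WordNet, and turning the probability into an edit budget of about p × sentence length (at least one edit) for RWSR, RWI, and RWS, in the style of Wei & Zou (2019), is an assumption.

```python
import random

PUNCTUATION = [".", ",", "!", "?", ";", ":"]
# Hypothetical stop-word list and synonym table, for illustration only.
STOP_WORDS = {"the", "and", "are", "a", "an", "is", "of", "to"}
SYNONYMS = {"story": ["plot"], "characters": ["roles"], "gripping": ["riveting"]}


def _n_changes(words, p):
    # Assumed edit budget: about p * len(words) edits, at least one.
    return max(1, int(p * len(words)))


def _candidates(words):
    # Indices of non-stop words that have an entry in the synonym table.
    return [i for i, w in enumerate(words) if w not in STOP_WORDS and w in SYNONYMS]


def rpi(words, p=0.3, rng=random):
    """Random punctuation insertion: a mark goes before each word with prob. p."""
    out = []
    for w in words:
        if rng.random() < p:
            out.append(rng.choice(PUNCTUATION))
        out.append(w)
    return out


def rwsr(words, p=0.1, rng=random):
    """Random synonym replacement on non-stop words that have a synonym."""
    out = list(words)
    idx = _candidates(out)
    rng.shuffle(idx)
    for i in idx[:_n_changes(words, p)]:
        out[i] = rng.choice(SYNONYMS[out[i]])
    return out


def rwi(words, p=0.1, rng=random):
    """Random insertion: put a synonym of a non-stop word at a random position."""
    out = list(words)
    for _ in range(_n_changes(words, p)):
        idx = _candidates(out)
        if not idx:
            break
        syn = rng.choice(SYNONYMS[out[rng.choice(idx)]])
        out.insert(rng.randint(0, len(out)), syn)
    return out


def rws(words, p=0.1, rng=random):
    """Random swap: exchange the positions of two randomly chosen words."""
    out = list(words)
    if len(out) < 2:
        return out
    for _ in range(_n_changes(words, p)):
        i, j = rng.sample(range(len(out)), 2)
        out[i], out[j] = out[j], out[i]
    return out


def rwd(words, p=0.1, rng=random):
    """Random deletion: drop each word with prob. p, keeping at least one word."""
    kept = [w for w in words if rng.random() >= p]
    return kept if kept else [rng.choice(words)]


sentence = "the story and characters are nowhere near gripping enough".split()
rng = random.Random(0)
for aug in (rpi, rwsr, rwi, rws, rwd):
    print(aug.__name__, " ".join(aug(sentence, rng=rng)))
```

Running the loop on the Fig. 1 sentence shows how easily RWD can remove a decisive word such as “nowhere” while RPI can append a “?”, the two failure cases discussed next.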

Following the approach of Tenney et al. (2020), we employ integrated gradients (IG) (Sundararajan, Taly & Yan, 2017) to map the importance of different words in the input xt. As shown in Fig. 1, for the input sample “the story and characters are nowhere near gripping enough,” whose true label is “negative,” the RWD method produced the transformed sample “story characters are near gripping enough,” deleting the crucial feature “nowhere.” This led the model to predict the sample as “very positive,” which contradicts the original label. Similarly, the RPI operation inserted a question mark (“?”) at the end of the sentence, shifting its meaning from a declarative statement to an interrogative one. Although such a change may enhance semantic richness during training, it is not suitable for the inference stage.
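Integrated gradients attributes a prediction F(x) to input feature i as $\mathrm{IG}_i(x) = (x_i - x'_i)\int_0^1 \partial_i F\big(x' + \alpha(x - x')\big)\,d\alpha$, where x' is a baseline input. The sketch below approximates the integral with a midpoint Riemann sum on a hypothetical three-feature logistic model; in the setting above, attributions would instead be computed on the fine-tuned model's token embeddings and summed per token. A useful sanity check is the completeness property: the attributions sum to F(x) − F(x').

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


# Hypothetical toy "model": F(x) = sigmoid(w . x + b) with made-up weights.
w = np.array([1.5, -2.0, 0.5])
b = 0.1


def F(x):
    return sigmoid(w @ x + b)


def grad_F(x):
    s = F(x)
    return s * (1.0 - s) * w  # analytic gradient of the logistic model


def integrated_gradients(x, baseline, steps=200):
    """Midpoint Riemann-sum approximation of IG along the straight path baseline -> x."""
    alphas = (np.arange(steps) + 0.5) / steps
    grads = np.stack([grad_F(baseline + a * (x - baseline)) for a in alphas])
    return (x - baseline) * grads.mean(axis=0)


x = np.array([1.0, 0.5, 1.0])   # stand-in for three per-token feature values
baseline = np.zeros_like(x)     # all-zero baseline, a common choice
attr = integrated_gradients(x, baseline)
print(attr)
print(attr.sum(), F(x) - F(baseline))  # completeness: the two should nearly agree
```

A negative attribution (here the second feature) plays the role of a word pushing the prediction away from the class, as “nowhere” does for the negative label in Fig. 1.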


Figure 1: Visualization of different data augmentation samples from input xt.

Here, RWD denotes the word deletion operation, RPI denotes the punctuation insertion operation, RWS denotes the word swap operation, RWI denotes the word insertion operation, and RWSR denotes the word synonym replacement operation.

Thus, if the predictions of all augmented samples are simply aggregated as the final prediction, it will greatly increase the risk of aggregation prediction error, which will lead to a significant drop in the performance of the model. Therefore, it is necessary to selectively aggregate augmented samples.

Methodology

In this section, we formally define our proposed method, STTA. We first introduce how data augmentation methods are applied at test time. We then describe, step by step, how augmented samples are identified based on similarity and confidence scores. Depending on whether their confidence and similarity are high or low, the augmented samples are divided into four distinct roles: Gold, Bonus, Potential, and Risk. Based on these roles, we design the selective test-time augmentation method (STTA). The pseudo-code and overall process of STTA are presented in Algorithm 1 and Fig. 2, respectively.

 
 
 
Algorithm 1 Proposed approach STTA.
    Input:   A source pre-trained model fθ;
             Target domain data D_T = {x_t}_{t=1}^{T};
    Require: Number of augmentations N;
             Confidence function C; Similarity function S;
             Augmentation function A;
    Output:  Prediction p̃_t ∈ Δ^C for each x_t.
  1: for t = 1, ..., T do
  2:     Sample N augmented inputs x̃_{t,i}, i = 1, ..., N, using augmentation function A.
  3:     Feed x̃_{t,i} forward to obtain the augmented logits matrix P_{N×C}.
  4:     Get the confidence and similarity quartile intervals Q_conf and Q_sim by Equation 2 and Equation 8.
  5:     Get the candidate set W_candidates by Equation 9.
  6:     Compute the final prediction probability vector p̃_t, instead of p_t, by Equation 12.
  7: end for


Figure 2: An overview of the process of our proposed STTA method.

Different colors represent different augmented samples. The color of the augmented sample is determined by the similarity between the augmented sample and the original sample and the confidence of the model prediction.

Data augmentation on test time

We first consider m types of augmentation to obtain N augmented samples, where N = n × m and n is the number of augmented samples generated by each type of augmentation. During inference at time step t, for a given test sample $x_t$, we obtain N augmented samples $\tilde{x}_{t,i}$ by applying the m types of augmentation functions $a_i \in \mathcal{A}$ to $x_t$. We then obtain the model output for each augmented sample:

$$p_{t,i} = f_\theta(\tilde{x}_{t,i}) \in \mathbb{R}^C,$$

where $f_\theta(\tilde{x}_{t,i})$ represents the logits associated with the i-th augmented sample $\tilde{x}_{t,i}$. Writing $p_{t,i} = (p_{t,i}^1, \dots, p_{t,i}^C)$ for all $i \in [1, n \times m]$, the augmented logits matrix $P_{N\times C}$ is defined as

$$P_{N\times C} = \big[p_{t,i}\big]_{i=1,\dots,N} = \begin{bmatrix} p_{t,1}^1 & p_{t,1}^2 & \cdots & p_{t,1}^C \\ p_{t,2}^1 & p_{t,2}^2 & \cdots & p_{t,2}^C \\ \vdots & \vdots & \ddots & \vdots \\ p_{t,N}^1 & p_{t,N}^2 & \cdots & p_{t,N}^C \end{bmatrix}.$$

Similarity-based recognition

Suppose we have a set of augmented samples $\tilde{x}_{t,i}$, where $i \in [1, N]$. Let $g_\phi: \mathcal{X} \to \mathbb{R}^d$ be the model encoder that maps each augmented sample $\tilde{x}_{t,i}$ to a d-dimensional feature space $\mathbb{R}^d$. The feature representations of the augmented samples are then

$$F_{N\times d} = \big[g_\phi(\tilde{x}_{t,i})\big]_{i=1,\dots,N},$$

where $g_\phi(\tilde{x}_{t,i}) \in \mathbb{R}^d$ represents the feature representation of the i-th augmented sample $\tilde{x}_{t,i}$.

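We then feed the feature representations $F_{N\times d}$ into the HNSW (Hierarchical Navigable Small World) algorithm (Malkov & Yashunin, 2018), a graph-based approximate nearest-neighbor search method, to calculate the similarity distances S between the augmented samples and the original sample, as shown in Algorithm 2.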
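For small batches such as N = 32, the distances that Algorithm 2 retrieves with HNSW can also be computed exactly by brute force, which is what the sketch below does; HNSW approximates the same nearest-neighbor distances. The distance metric is not specified in the text, so cosine distance is an assumption here, and the feature vectors are toy values rather than BERT encodings.

```python
import numpy as np


def similarity_distances(query: np.ndarray, feats: np.ndarray) -> np.ndarray:
    """Cosine distance 1 - cos(q, v_i) between the query and every row v_i of feats.

    query: g_phi(x_t), shape (d,); feats: F_{N x d}, one row per augmented sample.
    A smaller distance means the augmented sample is more similar to x_t.
    """
    q = query / np.linalg.norm(query)
    v = feats / np.linalg.norm(feats, axis=1, keepdims=True)
    return 1.0 - v @ q


# Toy features (d = 4): an unchanged copy, a slightly perturbed copy, and an
# augmentation whose representation points in the opposite direction.
q = np.array([1.0, 0.0, 2.0, 1.0])
feats = np.stack([q, q + np.array([0.0, 0.5, 0.0, 0.0]), -q])
S = similarity_distances(q, feats)
print(S)
```

The exact version costs O(N·d) per test sample, which is negligible at this scale; an approximate index mainly pays off when the set of vectors to search is large.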

 
 
Algorithm 2 Similarity-based recognition
Require: Augmented feature matrix F_{N×d}
Ensure: Similarity distance set S
 1: Initialize the graph G with a single node v0 and set v0 as the entry point
 2: Current level l ← 0
 3: for i = 1 to N do
 4:     vi ← HNSW(F_{N×d}, v0)
 5:     Add vi to G and connect it to v0
 6:     if i2l = 0 then
 7:         l ← l + 1
 8:     end if
 9: end for
10: Query feature q ← gϕ(xt)
11: Initialize the priority queue PQ with v0
12: Initialize the set of visited nodes V with v0
13: search ← 0
14: efSearch ← 100
15: while search < efSearch do
16:     v ← PQ.pop()
17:     for vi ∈ v do
18:         if vi ∉ V then
19:             V ← V ∪ {vi}
20:             PQ ← PQ ∪ {vi}
21:         end if
22:     end for
23:     search ← search + 1
24: end while
25: S ← {d(q, vi) | vi ∈ V}
26: return S

S contains the similarity distance between $x_t$ and each augmented sample $\tilde{x}_{t,i}$, so we can divide the augmented samples into quartile intervals by similarity distance:

$$Q_{\mathrm{sim}}(P_{N\times C}) = \bigcup_{q=1}^{4} S_q, \qquad S_q \in \{S_1, S_2, S_3, S_4\},$$

where $S_q$ represents the set of augmented samples in the q-th quartile interval of similarity distance, with $S_1$ containing the samples closest to (i.e., most similar to) $x_t$.

Confidence-based recognition

The confidence of a model prediction is defined as the maximum probability in the model's predicted probability distribution. For each augmented sample $\tilde{x}_{t,i}$, the confidence score is

$$\mathrm{conf}(\tilde{x}_{t,i}) = \max_{j=1,\dots,C} \sigma(p_{t,i})^{j}.$$

We then collect the confidence scores of the augmented samples into a set C, sort it in descending order, and divide it into quartile intervals:

$$Q_{\mathrm{conf}}(P_{N\times C}) = \bigcup_{g=1}^{4} C_g, \qquad C_g \in \{C_1, C_2, C_3, C_4\},$$

where $C_g$ represents the set of augmented samples in the g-th quartile interval of confidence score, with $C_1$ containing the most confident samples.
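A minimal sketch of both quartile splits follows, assuming equal-size groups by rank (with N = 32, eight samples per quartile). The text does not say whether quartile boundaries come from ranks or from score values, so this is one reasonable reading. Group 1 holds the most confident samples for $Q_{\mathrm{conf}}$ and the smallest distances for $Q_{\mathrm{sim}}$, matching the role definitions that follow; the scores are hypothetical.

```python
import numpy as np


def quartile_groups(scores: np.ndarray, descending: bool) -> list:
    """Split sample indices into four rank-based quartile groups, group 1 first.

    descending=True puts the largest scores in group 1 (confidence: C_1 = most confident);
    descending=False puts the smallest in group 1 (distance: S_1 = most similar).
    Ties keep index order because the sort is stable.
    """
    order = np.argsort(-scores if descending else scores, kind="stable")
    return [set(group.tolist()) for group in np.array_split(order, 4)]


def confidence(aug_logits: np.ndarray) -> np.ndarray:
    """conf(x~_{t,i}) = max_j softmax(p_{t,i})^j for every row of P_{N x C}."""
    z = aug_logits - aug_logits.max(axis=1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    return probs.max(axis=1)


# Toy scores for N = 8 augmented samples (hypothetical values).
conf = np.array([0.90, 0.20, 0.60, 0.95, 0.40, 0.70, 0.30, 0.80])
dist = np.array([0.05, 0.50, 0.10, 0.30, 0.20, 0.60, 0.40, 0.15])
C = quartile_groups(conf, descending=True)   # [C_1, C_2, C_3, C_4]
S = quartile_groups(dist, descending=False)  # [S_1, S_2, S_3, S_4]
print(C, S)
```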

Sample roles recognition

Based on these two perspectives, we divide all augmented samples into four roles according to whether their confidence and similarity are high or low. The details are as follows:

  • $W_{\text{gold}}$ refers to samples with high confidence that are the most similar to the original sample. These samples can effectively evaluate and improve the model's generalization ability, making them valuable for aggregation. It can be formulated as $W_{\text{gold}} = \{x_t^i \mid x_t^i \in (C_1 \cup C_2) \cap (S_1 \cup S_2)\}$.

  • $W_{\text{bonus}}$ refers to samples with high confidence and low similarity, which can provide additional feature information for aggregation. It can be formulated as $W_{\text{bonus}} = \{x_t^i \mid x_t^i \in (C_1 \cup C_2) \cap (S_3 \cup S_4)\}$.

  • $W_{\text{potential}}$ refers to samples with high similarity to the original sample but low confidence; they may have deviated somewhat during augmentation but still retain the features of the original sample. It can be formulated as $W_{\text{potential}} = \{x_t^i \mid x_t^i \in (C_3 \cup C_4) \cap (S_1 \cup S_2)\}$.

  • $W_{\text{risk}}$ refers to samples with both low confidence and low similarity to the original sample; they have lost the features of the original sample and are less valuable for aggregation. It can be formulated as $W_{\text{risk}} = \{x_t^i \mid x_t^i \in (C_3 \cup C_4) \cap (S_3 \cup S_4)\}$.
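The role assignment and the resulting candidate set can be sketched as follows. The rank-based quartile split and the toy scores are assumptions carried over for illustration; only the role logic (high/low halves of confidence and similarity) follows the definitions above.

```python
import numpy as np


def quartile_groups(scores, descending):
    """Rank-based quartile groups; group 1 holds the largest scores if descending."""
    order = np.argsort(-scores if descending else scores, kind="stable")
    return [set(group.tolist()) for group in np.array_split(order, 4)]


def assign_roles(conf, dist):
    """Map each augmented-sample index to 'gold', 'bonus', 'potential' or 'risk'."""
    C = quartile_groups(conf, descending=True)   # C[0] is C_1 (most confident)
    S = quartile_groups(dist, descending=False)  # S[0] is S_1 (most similar)
    high_conf = C[0] | C[1]
    high_sim = S[0] | S[1]
    roles = {}
    for i in range(len(conf)):
        if i in high_conf:
            roles[i] = "gold" if i in high_sim else "bonus"
        else:
            roles[i] = "potential" if i in high_sim else "risk"
    return roles


def candidate_set(roles):
    """W_candidates = W_gold | W_bonus | W_potential, i.e. every non-risk sample."""
    return sorted(i for i, role in roles.items() if role != "risk")


conf = np.array([0.90, 0.20, 0.60, 0.95, 0.40, 0.70, 0.30, 0.80])
dist = np.array([0.05, 0.50, 0.10, 0.30, 0.20, 0.60, 0.40, 0.15])
roles = assign_roles(conf, dist)
print(roles)
print(candidate_set(roles))
```

In this toy batch, samples 1 and 6 fall in the low-confidence, low-similarity corner and are dropped; the remaining six indices form the candidate set, so K = 6.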

Note that although $W_{\text{gold}}$ theoretically contributes the most, this does not imply that $W_{\text{risk}}$ and $W_{\text{potential}}$ cannot contribute. In practical applications, the selection can be adapted to specific circumstances. Here, we select $W_{\text{gold}}$, $W_{\text{bonus}}$, and $W_{\text{potential}}$ as candidate samples in all experiments.

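Based on the above analysis, we obtain the candidate set $W^K_{\text{candidates}}$ for aggregation:

$$W^K_{\text{candidates}} = \{x_t^i \mid x_t^i \in W_{\text{gold}} \cup W_{\text{bonus}} \cup W_{\text{potential}}\},$$

where K is the number of samples in the candidate set. The logits matrix of the selected samples, denoted $P_{K\times C}$, keeps the corresponding rows of $P_{N\times C}$:

$$P_{N\times C} \xrightarrow[\text{augmented samples}]{\text{select}} P_{K\times C} = \big[p_{t,i}\big]_{i=1,\dots,K} = \begin{bmatrix} p_{t,1}^1 & p_{t,1}^2 & \cdots & p_{t,1}^C \\ p_{t,2}^1 & p_{t,2}^2 & \cdots & p_{t,2}^C \\ \vdots & \vdots & \ddots & \vdots \\ p_{t,K}^1 & p_{t,K}^2 & \cdots & p_{t,K}^C \end{bmatrix}.$$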

The final prediction probability vector $\tilde{p}_t$ (used instead of $p_t$) is obtained by combining the selected logits matrix $P_{K\times C}$ with the original prediction probability vector $p_{t,0}$:

$$\tilde{p}_t = \omega\, p_{t,0} + (1-\omega)\,\frac{1}{K}\sum_{i=1}^{K}\sigma(p_{t,i}), \qquad \omega = \max\Big\{\omega \in [0,\alpha] : \arg\max_{j\in\{1,\dots,C\}} \tilde{p}_t^{\,j}(\omega) = \arg\max_{j\in\{1,\dots,C\}} p_{t,0}^{\,j}\Big\},$$

where ω is the aggregation weight, σ is the softmax function applied to each row of $P_{K\times C}$ to map it into $\Delta^C$, and α ← 0.5, since we consider the original prediction and the aggregated prediction equally important. More advanced techniques exist for determining the optimal value of α; nonetheless, our primary goal is to demonstrate the ability of our approach to mitigate bias through the memory bank.
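The aggregation step can be sketched as follows with toy inputs. One reasoning step helps when implementing it: because $\tilde{p}_t(\omega)$ moves linearly toward $p_{t,0}$ as ω grows, the weights for which the blend keeps the original class form an interval ending at ω = 1, so whenever the constraint can be met on [0, α] its maximum is α itself. The sketch therefore fixes ω = α = 0.5 (a simplification, labeled as such in the code) and only reports whether the constraint holds; it does not cover the case where it fails.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    """Row-wise softmax with a max-shift for numerical stability."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def stta_aggregate(p_orig: np.ndarray, cand_logits: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """p~_t = w * p_{t,0} + (1 - w) * (1/K) * sum_i softmax(p_{t,i}) over the K candidates.

    p_orig: probability vector predicted for the original input x_t, shape (C,).
    cand_logits: P_{K x C}, logits of the selected candidate samples.
    Simplification (assumption): w is fixed to alpha instead of being searched.
    """
    w = alpha
    return w * p_orig + (1.0 - w) * softmax(cand_logits).mean(axis=0)


p_orig = np.array([0.6, 0.4])                     # original prediction p_{t,0}
cand_logits = np.array([[2.0, 0.0], [0.0, 0.0]])  # K = 2 candidate rows
p_tilde = stta_aggregate(p_orig, cand_logits)
# Check the constraint in the equation: the blended prediction keeps the original class.
print(p_tilde, p_tilde.argmax() == p_orig.argmax())
```

Because the risk samples were already removed when building the candidate set, the averaged term here is computed only over gold, bonus, and potential rows.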

Experiments

Setup

Benchmark datasets and models

All experiments were conducted with the pre-trained BERT-base model (Devlin et al., 2018), whose weights we obtained from Hugging Face Transformers (https://huggingface.co/bert-base-uncased). We then fine-tuned it on seven benchmark datasets: MRPC (https://www.microsoft.com/en-us/download/details.aspx?id=52398) (Wang et al., 2018), Recognizing Textual Entailment (RTE) (https://huggingface.co/datasets/SetFit/rte) (Wang et al., 2018), SST-5 (Socher et al., 2013) (https://nlp.stanford.edu/sentiment/index.html), TREC-Fine, TREC-Coarse (https://cogcomp.seas.upenn.edu/Data/QA/QC/) (Li & Roth, 2002), SUBJ (https://github.com/facebookresearch/SentEval) (Conneau & Kiela, 2018), and TweetEval Emoji (https://github.com/cardiffnlp/tweeteval) (Barbieri et al., 2018).

Implementation details

Because this article focuses on test time, the fine-tuning stage of all our experiments follows the hyperparameter settings provided on each dataset's official website. We use the model training code provided by Hugging Face Transformers (https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification) and adopt Adam (Kingma & Ba, 2015) with an initial learning rate of 2e−5, a batch size of 32, and a maximum sequence length of 128. At test time, Partial-LR adopts the Adam optimizer (Kingma & Ba, 2015) with a learning rate of 2e−5; Smote sets the number of nearest neighbors k equal to the number of augmented samples; and Hard Vote selects the most frequent prediction as the final prediction. Our method involves no hyperparameter tuning. All experiments were conducted on a single NVIDIA RTX A6000 GPU with five different random seeds.

Compared methods

In this article, we mainly consider the following strong and representative baselines:

  • Baseline: Directly use the prediction of the original sample without TTA.

  • Mean (Krizhevsky, Sutskever & Hinton, 2017) (Standard TTA method): Average logit across all augmented samples. This is standard practice in TTA.

  • Max (Maximum Predicted Probability) (Guo et al., 2017): Maximum logit across all augmented samples. This baseline chooses the prediction with the highest confidence from the set of augmented predictions.

  • Hard Vote (Perikos & Hatzilygeroudis, 2016): Selects the most frequent prediction across all augmented samples.

  • Smote (Synthetic Minority Over-sampling Technique) (Fernández et al., 2018): Addresses class imbalance by synthesizing minority-class samples.

  • Partial-LR (Partial Logistic Regression) (Shanmugam et al., 2020): Learns N parameters, one per augmented sample, to aggregate the predictions.
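To make the differences between the simpler aggregation baselines concrete, the sketch below implements Mean, Max, and Hard Vote on a toy logits matrix. Two choices are assumptions: Mean averages softmax probabilities (following the standard TTA equation in the problem definition rather than raw logits), and Max follows the augmented sample whose softmax confidence is highest. Smote and Partial-LR are left out because they need neighbor synthesis or a trained aggregator.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def mean_tta(logits: np.ndarray) -> int:
    """Mean: average the class probabilities over all augmented samples."""
    return int(softmax(logits).mean(axis=0).argmax())


def max_tta(logits: np.ndarray) -> int:
    """Max: follow the single augmented sample with the highest confidence."""
    probs = softmax(logits)
    best = probs.max(axis=1).argmax()
    return int(probs[best].argmax())


def hard_vote(logits: np.ndarray) -> int:
    """Hard Vote: most frequent predicted class; argmax breaks ties toward the lower id."""
    votes = np.bincount(logits.argmax(axis=1), minlength=logits.shape[1])
    return int(votes.argmax())


# N = 4 augmented samples, C = 3 classes; row 0 is an overconfident outlier.
logits = np.array([[3.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0],
                   [0.0, 1.2, 0.0],
                   [0.0, 0.9, 0.0]])
print(mean_tta(logits), max_tta(logits), hard_vote(logits))
```

A single overconfident augmented copy (row 0) is enough to flip Max to class 0, while Mean and Hard Vote both follow the three copies that favor class 1.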

The Max, Partial-LR, and Mean methods are taken from Shanmugam et al. (2020); we cannot include all of their baselines because some are not applicable to text tasks. While these baselines reflect existing work, they are not the only ways to aggregate test-time augmented predictions. We therefore also include representative algorithms that aggregate sample predictions at test time, namely Hard Vote and Smote.

Experimental Results

Effect of augmentation policies

The factors that affect TTA include data augmentation methods and how to aggregate predictions effectively. In this section, we will conduct extensive discussions on different data augmentation methods to demonstrate the effectiveness of our method.

Note that although more complex data augmentation methods may be more effective in theory, they are not necessarily representative in practice. Because the above data augmentation methods admit a large number of possible combinations, we mainly consider the most representative ones and combine them into six data augmentation policies based on insights from Lu et al. (2022), as shown in Table 2, where RPI, RWSR, RWI, RWS, and RWD denote random punctuation insertion, random word synonym replacement, random word insertion, random word swap, and random word deletion, respectively. When the data augmentation policy is Aug. 6 (RWSR + RWI + RWS + RWD), the standard TTA method is equivalent to the TTA method of Lu et al. (2022).

Table 2:
Composition of six different augmentation policies considered in our experiments.
Aug. 1: RPI
Aug. 2: RWI
Aug. 3: RWD
Aug. 4: RWSR
Aug. 5: RWS
Aug. 6: RWSR + RWI + RWS + RWD
DOI: 10.7717/peerjcs.1757/table-2

Figure 3 illustrates a boxplot displaying the distribution of accuracy for different data augmentation policies. It can be observed that our method effectively adapts to various data transformations, even performing well in cases where semantic changes are likely to occur. This indicates that our approach successfully identifies anomalous samples from the batch data, preserving valuable contributions and thereby enhancing the accuracy of the model. In contrast, other methods lacking this capability experience a greater decline in performance when faced with data augmentation policies such as RWD.


Figure 3: Comparison of TTA methods under different data augmentation policies, reporting the average accuracy over all datasets.

Here, N = 32.

Additionally, we observe that in other methods, the outliers in the boxplot tend to be located in the region where the model’s net gain decreases. However, in our method, the outliers are generally situated in the region where the model’s net gain increases. This suggests that our approach has even greater potential to further improve model gains.

Although the gains achieved by our method are modest for some data augmentation policies, our approach yields positive absolute improvements under every data augmentation policy employed.

Main results

Table 3 reports the absolute accuracy improvements of all TTA methods on representative text classification benchmarks. In general, most TTA methods fail to produce positive gains and, to varying extents, degrade the model's performance. In contrast, our approach improves accuracy on every dataset, with an average gain of 0.50%. Notably, it achieves gains of 0.72% and 0.70% on RTE and MRPC, respectively, indicating its effectiveness on diverse types of datasets.

Table 3:
The absolute accuracy (%) improvement of different TTA methods on seven benchmark datasets over five random seeds. The Baseline row reports accuracy (%) without TTA; every other row reports the absolute change relative to it.
Here, the data augmentation policy is RWSR + RWI + RWS + RWD (Aug. 6) and N = 32.
Method | SST-5 | MRPC | RTE | SUBJ | TREC Coarse | TREC Fine | TweetEval Emoji | Avg.
Baseline | 52.85 | 77.80 | 66.06 | 96.50 | 96.40 | 92.00 | 32.17 | 73.40
Ours | 0.63 | 0.70 | 0.72 | 0.15 | 0.20 | 0.40 | 0.67 | 0.50
Mean | 0.45 | −0.52 | −2.53 | −0.40 | 0.10 | 0.60 | −0.56 | −0.41
Max | −0.50 | −0.52 | −5.05 | −0.45 | −1.00 | −10.60 | −0.27 | −2.63
Smote | 0.18 | −0.23 | −2.17 | −0.40 | 0.20 | 0.40 | −0.74 | −0.39
Hard Vote | 0.14 | −0.35 | −2.53 | −0.40 | 0.00 | 0.40 | −0.97 | −0.53
Partial-LR | 0.32 | −0.64 | −1.44 | −0.35 | −1.40 | 0.20 | −0.68 | −0.57
DOI: 10.7717/peerjcs.1757/table-3

Notes:

Results for the proposed method are shown in bold.
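The Avg. column and the claim that only our method never degrades accuracy can be checked directly from the per-dataset gains in Table 3; the short script below recomputes each method's mean gain and its worst per-dataset change.

```python
import numpy as np

# Per-dataset absolute gains from Table 3, in column order:
# SST-5, MRPC, RTE, SUBJ, TREC Coarse, TREC Fine, TweetEval Emoji.
gains = {
    "Ours":       [0.63, 0.70, 0.72, 0.15, 0.20, 0.40, 0.67],
    "Mean":       [0.45, -0.52, -2.53, -0.40, 0.10, 0.60, -0.56],
    "Max":        [-0.50, -0.52, -5.05, -0.45, -1.00, -10.60, -0.27],
    "Smote":      [0.18, -0.23, -2.17, -0.40, 0.20, 0.40, -0.74],
    "Hard Vote":  [0.14, -0.35, -2.53, -0.40, 0.00, 0.40, -0.97],
    "Partial-LR": [0.32, -0.64, -1.44, -0.35, -1.40, 0.20, -0.68],
}

for method, g in gains.items():
    print(f"{method:<10}  avg {np.mean(g):+.2f}  worst {min(g):+.2f}")
```

Running it reproduces the reported averages to two decimals and shows that the proposed method is the only one whose worst per-dataset change is positive.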

Partial-LR uses a learnable network to assign weights to different samples, but it introduces the additional assumption that the source data is accessible. Moreover, it still lowers average accuracy (−0.57%). As discussed in ‘Related Work’, unreasonable data augmentation operations may destroy the sample label, making it difficult for Partial-LR to learn the correct weight for each augmented sample.

When noisy augmented samples make predictions uncertain, the baseline model may struggle to predict robustly, and overconfident yet erroneous predictions may arise. This makes it difficult for the Max method to pick the right high-confidence sample from a set of augmented samples, which results in a notable 10.60% drop on TREC Fine. Hard Vote faces the same obstacle. Furthermore, although Smote ties for the second-best gain on TREC Fine (0.40%), its interpolation on the original samples may also be impaired by additional noisy samples, so erroneous predictions may persist.

We also observe that when the baseline model performs relatively poorly, as on the RTE and MRPC datasets, STTA achieves larger improvements. As baseline performance increases, for example on SUBJ, TREC Coarse, and TREC Fine, where accuracy exceeds 90%, the improvement shrinks. On TREC Coarse, which has a strong baseline, some TTA methods such as Hard Vote yield no improvement at all. This suggests that TTA may be most beneficial when the baseline model performs poorly. Additionally, Mean achieves the highest improvement on TREC Fine with a simple average but performs poorly on the other datasets, which suggests that TTA has considerable potential but lacks a principled way to identify the noisy samples produced by inappropriate data augmentation. Overall, although our method brings only small gains on datasets where the model already performs well, such as SUBJ, it does not degrade performance on any dataset. These results indicate that selectively aggregating predictions based on sample roles is effective.
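The relationship between baseline accuracy and STTA's gain can be checked directly from the numbers in Table 3. The sketch below copies the Baseline and Ours rows and computes a Pearson correlation together with the mean gain on datasets below and above 90% baseline accuracy. The 90% split follows the text; with only seven datasets this is a descriptive check, not a significance test, and the helper names are illustrative.

```python
from statistics import mean
from typing import Sequence, Tuple

# Values copied from Table 3: baseline accuracy (%) and STTA ("Ours") improvement.
DATASETS = ["SST-5", "MRPC", "RTE", "SUBJ", "TREC Coarse", "TREC Fine", "TweetEval Emoji"]
BASELINE = [52.85, 77.80, 66.06, 96.50, 96.40, 92.00, 32.17]
STTA_GAIN = [0.63, 0.70, 0.72, 0.15, 0.20, 0.40, 0.67]


def pearson(xs: Sequence[float], ys: Sequence[float]) -> float:
    """Pearson correlation coefficient of two equal-length sequences."""
    mx, my = mean(xs), mean(ys)
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = sum((x - mx) ** 2 for x in xs) ** 0.5
    sy = sum((y - my) ** 2 for y in ys) ** 0.5
    return cov / (sx * sy)


def split_gains(threshold: float = 90.0) -> Tuple[float, float]:
    """Mean STTA gain on datasets below / at-or-above the baseline-accuracy threshold."""
    low = [g for b, g in zip(BASELINE, STTA_GAIN) if b < threshold]
    high = [g for b, g in zip(BASELINE, STTA_GAIN) if b >= threshold]
    return mean(low), mean(high)


if __name__ == "__main__":
    r = pearson(BASELINE, STTA_GAIN)
    low, high = split_gains()
    print(f"r = {r:.2f}; mean gain below 90%: {low:.2f}, at or above 90%: {high:.2f}")
```

The correlation comes out negative, and the mean gain is 0.68 on the four lower-baseline datasets versus 0.25 on the three higher-baseline ones, consistent with the trend described for STTA.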

Analysis of Label Correction and Corruption

Figure 4 shows the number of samples that were corrected (the model's original incorrect prediction is changed to a correct prediction by TTA) or corrupted (the model's original correct prediction is changed to an incorrect prediction by TTA). Although the standard TTA method corrects more predictions than it corrupts on certain datasets (e.g., SST-5), it leads to more corruptions than corrections on most datasets. On some datasets (e.g., TREC Fine), the number of corrupted labels is more than double the number of corrected labels. This occurs because the standard TTA method does not distinguish between augmented samples and cannot accurately judge their importance. Consequently, it often aggregates harmful augmented samples together with valid ones.
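The corrected and corrupted counts in Figure 4 follow directly from the definitions above. The sketch below is a minimal version of that bookkeeping, assuming integer class labels and one baseline prediction plus one TTA prediction per test sample; the function name is hypothetical.

```python
from typing import Sequence, Tuple


def count_corrections(labels: Sequence[int],
                      base_preds: Sequence[int],
                      tta_preds: Sequence[int]) -> Tuple[int, int]:
    """Return (corrected, corrupted).

    corrected: the baseline prediction is wrong and the TTA prediction is right.
    corrupted: the baseline prediction is right and the TTA prediction is wrong.
    Samples on which both are right, or both are wrong, count toward neither.
    """
    if not (len(labels) == len(base_preds) == len(tta_preds)):
        raise ValueError("all three sequences must have the same length")
    corrected = corrupted = 0
    for y, b, t in zip(labels, base_preds, tta_preds):
        if b != y and t == y:
            corrected += 1
        elif b == y and t != y:
            corrupted += 1
    return corrected, corrupted


if __name__ == "__main__":
    labels = [0, 1, 2, 1, 0, 2]
    base_preds = [0, 2, 2, 1, 1, 0]
    tta_preds = [0, 1, 0, 1, 0, 1]
    print(count_corrections(labels, base_preds, tta_preds))  # (2, 1)
```

The net accuracy change is (corrected − corrupted) divided by the number of test samples, so a method can correct many predictions and still lose accuracy if it corrupts even more.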

Figure 4: Number of corrected and corrupted samples by different TTA methods on different datasets.

Here, the data augmentation policy is RWSR + RWI + RWS + RWD and N = 32.

To overcome this disadvantage, Partial-LR uses regression to estimate the weights of different augmented samples and then computes the final prediction from these weights. However, on the SUBJ, TREC Coarse, and TREC Fine datasets, it produced more than twice as many corrupted labels as corrected ones. Similarly, Smote caused severe label corruption on SUBJ and TREC Coarse. These corruptions arise because the RWSR + RWI + RWS + RWD data augmentation policy used in the experiments readily generates abnormal, noisy augmented samples. Partial-LR and Smote cannot reliably identify these samples, which greatly increases the likelihood of label corruption during aggregation.

In summary, because these methods cannot distinguish the importance of augmented samples, their aggregated predictions are easily affected by poor augmented samples, which results in incorrect predictions.

Ablation Study

Necessity of both similarity and confidence

To verify the importance of both confidence and similarity in STTA, we denote high- (low-) confidence samples as Ch (Cl) and high- (low-) similarity samples as Sh (Sl), and conduct the following ablation experiments:

  • STTA w/o Conf.: When discriminating the roles of the augmented samples, we only use similarity to divide them into two types: Sh and Sl.

  • STTA w/o Sim.: When discriminating the roles of the augmented samples, we only use confidence scores to divide them into two types: Ch and Cl.

  • STTA with Conf.&Sim.: When discriminating the roles of the augmented samples, we use both similarity and confidence scores to divide them.
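The three ablation variants differ only in which criteria are used to split the augmented samples. The sketch below makes this concrete under stated assumptions: confidence is taken as the top-class softmax probability, similarity as the cosine similarity between the prediction vectors of the augmented copy and the original input, and the thresholds `tau_c` and `tau_s` are hypothetical. The paper's own confidence measure (Eq. (7)) and similarity measure may be defined differently.

```python
import math
from typing import Dict, List, Tuple

Probs = List[float]


def confidence(p: Probs) -> float:
    # ASSUMPTION: confidence = top-class probability; the paper's Eq. (7) may differ.
    return max(p)


def cosine(u: Probs, v: Probs) -> float:
    # ASSUMPTION: similarity = cosine similarity of the two prediction vectors.
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def partition(original: Probs, augmented: List[Probs],
              tau_c: float = 0.8, tau_s: float = 0.9,
              use_conf: bool = True, use_sim: bool = True) -> Dict[Tuple[str, ...], List[int]]:
    """Group augmented copies by confidence and/or similarity level.

    With both criteria the keys are ("Ch"|"Cl", "Sh"|"Sl"), giving four groups
    (STTA with Conf.&Sim.). use_conf=False keeps only the similarity key
    (STTA w/o Conf.); use_sim=False keeps only the confidence key (STTA w/o Sim.).
    Values are indices into `augmented`. tau_c and tau_s are hypothetical thresholds.
    """
    groups: Dict[Tuple[str, ...], List[int]] = {}
    for i, p in enumerate(augmented):
        key: Tuple[str, ...] = ()
        if use_conf:
            key += ("Ch" if confidence(p) >= tau_c else "Cl",)
        if use_sim:
            key += ("Sh" if cosine(p, original) >= tau_s else "Sl",)
        groups.setdefault(key, []).append(i)
    return groups


if __name__ == "__main__":
    original = [0.7, 0.2, 0.1]
    augmented = [[0.9, 0.05, 0.05], [0.6, 0.3, 0.1], [0.05, 0.05, 0.9], [0.3, 0.4, 0.3]]
    print(partition(original, augmented))
    print(partition(original, augmented, use_conf=False))
    print(partition(original, augmented, use_sim=False))
```

Dropping either criterion collapses the four groups into two, which is why the single-criterion variants cannot separate, for example, a confident but dissimilar copy (index 2 above) from a confident and similar one (index 0).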

Fig. 5 shows that using only similarity or only confidence reduces the performance gain, and that STTA w/o Conf. performs worse than STTA w/o Sim., which indicates that similarity plays a greater role than confidence in aggregating predictions. Moreover, STTA with Conf.&Sim. achieves the best performance, which demonstrates the necessity of combining similarity and confidence to define the different roles of augmented samples.

Figure 5: Necessity of both similarity and confidence. Here, the data augmentation policy is RWSR + RWI + RWS + RWD and N = 32.

Effect of different roles

To verify the necessity of the different sample roles (Wgold, Wbonus, Wpotential, and Wrisk) and of combining them, we evaluated each role separately. As shown in Fig. 6, the gains from Wgold, Wbonus, Wpotential, and Wrisk decrease linearly in that order, which supports the effectiveness of our method in distinguishing valid samples. However, no single role performs as well as our full method, which combines all roles, indicating that combining different roles is necessary.
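The single-role versus combined-role comparison can be pictured as switching per-role weights on and off in a weighted average. The sketch below is hypothetical: it assumes each augmented copy has already been assigned exactly one role, that the original input enters the average with weight 1, and that the final prediction is a weighted average of probability vectors. The weight values, and which confidence/similarity group each role corresponds to, are placeholders rather than the settings defined in the paper.

```python
from typing import Dict, List, Sequence

Probs = List[float]


def role_weighted_average(original: Probs,
                          augmented: Sequence[Probs],
                          roles: Sequence[str],
                          role_weight: Dict[str, float]) -> Probs:
    """Average the original prediction (weight 1) with role-weighted augmented predictions.

    roles[i] names the role of augmented[i]; roles missing from role_weight, or
    mapped to a non-positive weight, are ignored. HYPOTHETICAL scheme for illustration.
    """
    total = list(original)
    weight_sum = 1.0
    for p, role in zip(augmented, roles):
        w = role_weight.get(role, 0.0)
        if w <= 0.0:
            continue  # this role is switched off
        total = [t + w * x for t, x in zip(total, p)]
        weight_sum += w
    return [t / weight_sum for t in total]


if __name__ == "__main__":
    original = [0.5, 0.5]
    augmented = [[1.0, 0.0], [0.8, 0.2], [0.0, 1.0]]
    roles = ["gold", "bonus", "risk"]
    # Single-role variant: keep only "gold" copies.
    print(role_weighted_average(original, augmented, roles, {"gold": 1.0}))
    # Combined variant with hypothetical weights; "risk" copies are dropped.
    print(role_weighted_average(original, augmented, roles,
                                {"gold": 1.0, "bonus": 0.5, "risk": 0.0}))
```

Because the weights are normalized by their sum, the output remains a probability vector whenever the inputs are.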

Figure 6: Effect of Wgold, Wbonus, Wpotential and Wrisk on SST-5.

Here, the data augmentation policy is RWSR + RWI + RWS + RWD and N = 32.

Effect of number of augmentations

Fig. 7 shows the impact of the number of augmentations on TTA performance. Notably, STTA exhibits a clear linear positive correlation between the number of augmentations and model performance. In contrast, the standard TTA method shows no discernible benefit as the number of augmentations increases, and the Max method degrades further. Although Partial-LR mitigates the detrimental effects as the number of augmentations increases, it still fails to achieve positive gains.

Figure 7: Average absolute accuracy (%) improvement of various TTA methods with different numbers of data augmentations, averaged over all datasets and five random seeds.

Conclusion and limitations

In this article, we propose Selective Test-Time Augmentation (STTA), a simple yet effective alternative to the standard TTA method. STTA aims to overcome the limitations of standard TTA and reduce the model's sensitivity to abnormal augmented samples by recognizing the roles of augmented samples. Unlike prior advanced TTA methods, STTA requires neither access to source data nor additional training. Furthermore, STTA does not interfere with the training of the backbone network and can be combined with other robustness methods to further enhance the model's performance. Finally, the method is straightforward and efficient, and its plug-and-play design allows seamless integration with any existing model.

For future work, we plan to explore the following directions:

  • Investigate how to selectively apply TTA to only those samples that need to be corrected, rather than the entire test set, in order to minimize time and computational costs.

  • Design more reasonable and effective data augmentation methods, especially for test time, to further improve model performance.

  • Explore the application of STTA in other NLP tasks, such as neural machine translation and question answering, to evaluate its effectiveness and generalizability.

  • A potential risk of our method is that the confidence mechanism in Eq. (7) may not fully reflect the reliability of the augmented samples, because neural networks tend to be overconfident (Wei et al., 2022), although we also use similarity as an additional criterion. We plan to explore better uncertainty estimation methods for evaluating the augmented samples (Ovadia et al., 2019).

Supplemental Information

The source code for our proposed STTA method.

The readme explains the source code and the startup file. main.py is the entry file, and TTA Method.py contains the different TTA methods.

DOI: 10.7717/peerj-cs.1757/supp-1

MRPC dataset.

This dataset is a port of the official mrpc dataset on the Hub. Note that the sentence1 and sentence2 columns have been renamed to text1 and text2 respectively. Also, the test split is not labeled; the label column values are always -1.

DOI: 10.7717/peerj-cs.1757/supp-2

SST5 dataset.

Stanford Sentiment Treebank with 5 labels: very positive, positive, neutral, negative, very negative

DOI: 10.7717/peerj-cs.1757/supp-3

SUBJ dataset.

It contains movie-related sentences, each annotated as either subjective or objective.

DOI: 10.7717/peerj-cs.1757/supp-4

TREC-Coarse dataset.

The Text REtrieval Conference (TREC) Question Classification dataset contains 5500 labeled questions in the training set and another 500 in the test set. The TREC-Coarse dataset has 6 coarse class labels.

DOI: 10.7717/peerj-cs.1757/supp-5

TREC-Fine dataset.

The Text REtrieval Conference (TREC) Question Classification dataset contains 5500 labeled questions in the training set and another 500 in the test set. The TREC-Fine dataset has 50 fine class labels.

DOI: 10.7717/peerj-cs.1757/supp-6

RTE dataset.

The RTE (Recognizing Textual Entailment) dataset is a commonly used natural language inference dataset for evaluating whether a model can judge if a hypothesis follows from a given premise.

DOI: 10.7717/peerj-cs.1757/supp-7

Tweet eval emoji dataset.

The task of the TweetEval emoji dataset is to predict the emoji that accompanies a given tweet. It includes 20 emoji classes.

DOI: 10.7717/peerj-cs.1757/supp-8