Differentiable Subset Pruning of Transformer Heads
Jiaoda Li♣ Ryan Cotterell♣♠ Mrinmaya Sachan♣
♠University of Cambridge, UK
♣ETH Zürich, Switzerland
{jiaoda.li,ryan.cotterell,mrinmaya.sachan}@inf.ethz.ch
Abstract
Multi-head attention, a collection of several attention mechanisms that independently attend to different parts of the input, is the key ingredient in the Transformer. Recent work has shown, however, that a large proportion of the heads in a Transformer’s multi-head attention mechanism can be safely pruned away without significantly harming the performance of the model; such pruning leads to models that are noticeably smaller and faster in practice. Our work introduces a new head pruning technique that we term differentiable subset pruning. Intuitively, our method learns per-head importance variables and then enforces a user-specified hard constraint on the number of unpruned heads. The importance variables are learned via stochastic gradient descent. We conduct experiments on natural language inference and machine translation; we show that differentiable subset pruning performs comparably to or better than previous work while offering precise control of the sparsity level.1
1 Introduction
The Transformer (Vaswani et al., 2017) has become one of the most popular neural architectures used in NLP. Adaptations of the Transformer have been applied to nearly every popular NLP task, for example, parsing (Zhou and Zhao, 2019), machine translation (Ng et al., 2019), and question answering (Yang et al., 2019), inter alia. Transformers also form the backbone of state-of-the-art pre-trained language models, such as BERT (Devlin et al., 2019), GPT-2 (Radford et al., 2019), and GPT-3 (Brown et al., 2020), which have further boosted performance on various data-driven NLP problems.
The key ingredient in the Transformer architecture is the multi-head attention mechanism, which is an assembly of multiple attention functions (Bahdanau et al., 2015) applied in parallel. In practice, each attention head works independently, which allows the heads to capture different kinds of linguistic phenomena (Clark et al., 2019; Goldberg, 2019; Ettinger, 2020; Jawahar et al., 2019). A natural question arises in this context: How many heads does a Transformer need?

1 Our code is available here: https://github.com/rycolab/differentiable-subset-pruning.
Michel et al. (2019) offer the insight that a large portion of the Transformer’s heads can be pruned without significantly degrading the test accuracy on the desired task. The experimental evidence behind their claim is a simple greedy procedure that sequentially removes heads. This suggests that a better pruner could reveal that a much larger portion of the heads can be safely removed. To provide a more robust answer to Michel et al.’s question, we build a high-performance pruner and show that their approach itself significantly underestimates the number of Transformer heads that can be pruned away.
From a bird’s-eye view, our paper contributes the proposal that Transformer head pruning is best viewed as a subset selection problem. Subset selection is common across many areas of NLP, from extractive summarization (Gillenwater et al., 2012) to vowel typology (Cotterell and Eisner, 2017). In the case of head pruning, the concrete idea is that the user specifies the number of heads K that they would like their Transformer to have, depending on their budget and other constraints, and the pruner then enforces this constraint. Methodologically, we present a differentiable subset pruner (Figure 1) that makes use of Gumbel machinery, specifically, the Gumbel top-K procedure of Vieira (2014). This construction allows us to relax our pruner into a differentiable sampling routine that qualitatively resembles a discrete analogue of dropout (Srivastava et al., 2014; Gal and Ghahramani, 2016).
Empirically, we perform experiments on two common NLP tasks: natural language inference (MNLI; Williams et al., 2018) and machine translation (IWSLT2014; Cettolo et al., 2014). We show that our differentiable subset pruning scheme outperforms two recently proposed Transformer head pruners, those of Michel et al. (2019) and Voita et al. (2019), on both tasks in terms of the sparsity–performance trade-off. Our method recovers a pruned Transformer that achieves ≈ 80% accuracy on MNLI and ≈ 30 BLEU on IWSLT when more than 90% of the heads are removed, which brings about an ≈ 33% inference speedup and an ≈ 24% reduction in model size.2

Figure 1: Illustration of gated multi-head attention compared with standard multi-head attention.

Transactions of the Association for Computational Linguistics, vol. 9, pp. 1442–1459, 2021. https://doi.org/10.1162/tacl_a_00436
Action Editor: Noah Smith. Submission batch: 5/2021; Revision batch: 7/2021; Published 12/2021.
© 2021 Association for Computational Linguistics. Distributed under a CC-BY 4.0 license.
Our experiments also suggest several broader conclusions about pruning Transformers. In this paper, we taxonomize existing pruning methods into two pruning paradigms: pipelined pruning and joint pruning. Pipelined pruning consists of two stages: (i) training or fine-tuning an over-parameterized model on the target task and (ii) pruning the model after training. A number of techniques fall into this category (LeCun et al., 1990; Hassibi et al., 1994; Han et al., 2016; Molchanov et al., 2017b). In contrast, joint pruning blends the pruning objective into the training objective by training or fine-tuning the over-parameterized model with a sparsity-enforcing regularizer, sometimes followed by a trivial post-processing step to arrive at a final sparse model. Kingma et al. (2015) and Louizos et al. (2018) are examples of this kind of pruning. We show that pipelined head pruning schemes, such as that of Michel et al., underperform joint head pruning schemes, such as that of Voita et al. (2019). Our differentiable subset pruner can be adapted to both paradigms, and it outperforms prior work in both, especially in high-sparsity regimes.

2 See § 5.4.
2 Background: Multi-head Attention
In this section, we provide a detailed overview of multi-head attention (Vaswani et al., 2017) in order to develop the specific technical vocabulary needed to discuss our approach to head pruning. We omit details about other parts of the Transformer and refer the reader back to the original work of Vaswani et al. (2017). First, let $z = z_1, \ldots, z_T$ be a sequence of T real vectors where $z_t \in \mathbb{R}^d$, and let $q \in \mathbb{R}^d$ be a query vector. An attention mechanism is defined as

$$\mathrm{att}(z, q) = W_o \sum_{t=1}^{T} \alpha_t(q)\, W_v z_t \tag{1}$$

where

$$\alpha_t(q) = \mathrm{softmax}\left(\frac{q^\top W_q^\top W_k z_t}{\sqrt{d}}\right)_t \tag{2}$$
The projection matrices $W_o, W_v, W_q, W_k \in \mathbb{R}^{d \times d}$ are learnable parameters. In self-attention, the query q is itself drawn from the same sequence z.
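For concreteness, the following is a minimal pure-Python sketch of a single attention head that follows (1) and (2) term by term. The list-of-lists matrices and the helper names `matvec` and `att` are illustrative assumptions, not the released implementation; the scores use the identity $q^\top W_q^\top W_k z_t = (W_q q) \cdot (W_k z_t)$.

```python
import math

def matvec(W, x):
    """Multiply a matrix, stored as a list of rows, by a vector."""
    return [sum(w * xj for w, xj in zip(row, x)) for row in W]

def att(z, q, Wo, Wv, Wq, Wk):
    """One attention head following eqs. (1)-(2); z is a list of T vectors in R^d."""
    d = len(q)
    query = matvec(Wq, q)
    # Scores: q^T W_q^T W_k z_t / sqrt(d), computed as (W_q q) . (W_k z_t) / sqrt(d).
    scores = [sum(a * b for a, b in zip(query, matvec(Wk, z_t))) / math.sqrt(d) for z_t in z]
    # Softmax over positions t (shifted by the max for numerical stability) gives alpha_t(q).
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    alpha = [e / total for e in exps]
    # Weighted sum of the value vectors W_v z_t, followed by the output projection W_o.
    mixed = [0.0] * d
    for a_t, z_t in zip(alpha, z):
        mixed = [m + a_t * v for m, v in zip(mixed, matvec(Wv, z_t))]
    return matvec(Wo, mixed)
```

With $W_q = 0$ every score is 0, so the weights $\alpha_t(q)$ are uniform and the head simply averages the projected values $W_v z_t$.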
A Transformer is composed of L identical layers. In layer 1 ≤ l ≤ L, $H_l$ different attention mechanisms are applied in parallel; importantly, it is this parallelism that has led to the rise of the Transformer: it is a more efficient architecture in practice, so it can be trained on more data. Each individual attention mechanism is referred to as a head; thus, multi-head attention is the simultaneous application of multiple attention heads in a single architecture. In Vaswani et al. (2017), the multiple heads are combined through summation (concatenating the head outputs and applying one output projection is equivalent to summing the heads when each head owns its slice of the output projection):

$$\mathrm{mhatt}_l(z, q) = \sum_{h=1}^{H_l} \mathrm{att}_{lh}(z, q) \tag{3}$$

where $\mathrm{att}_{lh}$ is the h-th attention head in the l-th layer.
We also introduce a gate variable $g_{lh}$ that takes values in the interval [0, 1]:

$$\mathrm{gmhatt}_l(z, q) = \sum_{h=1}^{H_l} g_{lh} \cdot \mathrm{att}_{lh}(z, q) \tag{4}$$

Inserting $g_{lh}$ into the multi-head attention enables our pruning approach: setting the gate variable to $g_{lh} = 0$ means the head $\mathrm{att}_{lh}$ is pruned away.
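The gated combination in (4) can be sketched in a few lines. To keep the sketch short, it assumes the head outputs $\mathrm{att}_{lh}(z, q)$ have already been computed and are passed in as vectors; the function name `gated_mhatt` is illustrative. Setting every gate to 1 recovers (3), while a gate of 0 removes that head's contribution, which is why a pruned head need not be computed at all.

```python
def gated_mhatt(head_outputs, gates):
    """Eq. (4): sum over heads of g_lh * att_lh(z, q), given each head's output vector."""
    if len(head_outputs) != len(gates):
        raise ValueError("need exactly one gate per head")
    if any(not 0.0 <= g <= 1.0 for g in gates):
        raise ValueError("gates must lie in [0, 1]")
    out = [0.0] * len(head_outputs[0])
    for g, o in zip(gates, head_outputs):
        if g == 0.0:
            continue  # a pruned head contributes nothing, so it can be skipped entirely
        out = [acc + g * x for acc, x in zip(out, o)]
    return out
```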
In the following sections, for the sake of notational simplicity, we ignore the layer structure of heads and label heads with a single index h ∈ {1, …, H}, where $H = \sum_{l=1}^{L} H_l$ is the total number of heads in the unpruned model.
3 Differentiable Subset Pruning
In this section, we propose a new head pruning technique that we term differentiable subset pruning. The key insight behind our approach is that head pruning can be viewed as subset selection. Concretely, our goal is to find a subset of K heads (where K is a user-specified positive integer) that still allows the model to achieve high performance. Many neural network pruners, for example, the head pruning technique proposed by Voita et al. (2019), make it notably difficult to pre-specify the number of retained heads K.3 To make our subset pruner differentiable, we apply the Gumbel-softmax trick (Maddison et al., 2017) and its extension to subset selection (Vieira, 2014; Xie and Ermon, 2019). This gives us a pruning scheme that always returns the specified number of heads and can be applied in a pipelined or a joint setting. In both cases, differentiability is necessary to learn the head weights.

3 Discussed further in § 5.2.
3.1 Background: Gumbel-(soft)max
Let $\mathcal{H} = \{1, \ldots, H\}$ be the set of Transformer heads in a given architecture. Our goal is to return a subset of heads $\mathcal{J} \subseteq \mathcal{H}$ where $|\mathcal{J}| = K$ for any user-specified value of K. We use the notation $\iota_h > 0$ to denote the importance score of head h. The head importance score intuitively corresponds to how much we would like to have head h in the subset of heads $\mathcal{J}$.

We start our exposition by reviewing the Gumbel trick in the context of selecting a single head (K = 1) and then move on to its extension to subset selection. Given the head importance scores $\iota_h$, suppose we would like to sample a subset $\mathcal{J}$ of size 1 according to the following distribution:

$$p(\mathcal{J} = \{h\}) = \frac{\iota_h}{Z} \propto \iota_h \tag{5}$$

where $Z = \sum_{h=1}^{H} \iota_h$ is the normalization constant. The simplest way to achieve this is to use standard categorical sampling. However, as has been noted by Maddison et al. (2014), categorical sampling is not differentiable: the sampled head is not a differentiable function of the importance scores. Luckily, there is a two-step process to massage categorical sampling into a differentiable sampling procedure: (1) reparameterize the categorical using Gumbels and (2) soften the argmax into a softmax.
3.1.1 Step 1: Reparameterization
We can reparameterize categorical sampling using the Gumbel-max trick (Gumbel, 1954) to first separate the sampling from the parameter with respect to which we wish to differentiate. The idea of the Gumbel-max trick is that categorical sampling can be viewed as a perturb-and-max method. If we first perturb the logits $\log(\iota_h)$ with Gumbel noise $n_h \sim \mathrm{Gumbel}(0, 1)$ such that $r_h = \log(\iota_h) + n_h$, then sampling from a categorical is equivalent to taking an argmax:

$$h^* = \operatorname*{argmax}_{h \in \mathcal{H}} r_h \tag{6}$$

Were argmax differentiable, we would be done; unfortunately, it is not.
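The perturb-and-max view is easy to check numerically. The sketch below, a minimal illustration using only Python's standard library, draws Gumbel(0, 1) noise by inverting its CDF and returns the argmax of (6); the clamping constant and the function names are our own illustrative choices. Note that the importance scores enter only through the deterministic term $\log(\iota_h)$, while all the randomness lives in $n_h$.

```python
import math
import random

def sample_gumbel(rng):
    """Draw n ~ Gumbel(0, 1) by inverting its CDF F(n) = exp(-exp(-n))."""
    # Clamp u strictly inside (0, 1) so that both logarithms stay finite.
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
    return -math.log(-math.log(u))

def gumbel_max_sample(iota, rng):
    """Eq. (6): perturb each logit log(iota_h) with Gumbel noise and return the argmax."""
    r = [math.log(i) + sample_gumbel(rng) for i in iota]
    return max(range(len(r)), key=r.__getitem__)

# Usage: head h should be picked with probability iota_h / Z, as in eq. (5).
rng = random.Random(0)
iota = [1.0, 2.0, 3.0, 4.0]
counts = [0] * len(iota)
for _ in range(20_000):
    counts[gumbel_max_sample(iota, rng)] += 1
print([c / 20_000 for c in counts])  # close to [0.1, 0.2, 0.3, 0.4]
```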
3.1.2 Step 2: Relaxing the argmax
Now, to construct a fully differentiable procedure, we replace the argmax with a softmax. The intuition here is that the output of argmax may be viewed as a one-hot vector with the one at the index of the argmax.4 The insight, then, is to relax the one-hot vector output by the argmax into a softmax as follows:

$$g_h = \frac{\exp(r_h)}{\sum_{h'=1}^{H} \exp(r_{h'})} \tag{7}$$

This technique is called the Gumbel-softmax trick (Jang et al., 2017), and the resulting distribution is known as the Concrete distribution (Maddison et al., 2017).5 It is often desirable to add an additional annealing parameter τ > 0 to the Gumbel-softmax:

$$g_h = \frac{\exp(r_h / \tau)}{\sum_{h'=1}^{H} \exp(r_{h'} / \tau)} \tag{8}$$

As the temperature tends to zero, that is, τ → 0, the softmax turns into the argmax. Thus, through the tunable τ, we can approximate the argmax arbitrarily well with a differentiable function.

4 More precisely, argmax returns a set; in our terminology, it would return a multi-hot vector when there are ties. We ignore this case in our exposition for simplicity.
5 Using the Gumbel-softmax results in a biased estimate of the gradient. Subsequent work removed this bias (Tucker et al., 2017).

3.2 Differentiable Subset Selection
The Gumbel trick can be generalized to cases where we wish to sample an entire set of heads. This is called the Gumbel-top-K trick. The idea is that, rather than simply taking the max, we sort and then take the top-K largest perturbed logits (Yellott, 1977; Vieira, 2014; Kool et al., 2019). One way to think of the algorithm is that we repeat the Gumbel trick K times, without replacement, until we have the desired number of heads. Following the exposition in § 3.1, we divide our discussion into two steps.

3.2.1 Step 1: Reparameterization
Similar to the top-1 case, we start by sampling the first head using the perturb-and-max strategy:

$$h^*_1 = \operatorname*{argmax}_{h \in \mathcal{H}} r_h \tag{9}$$

Then we remove $h^*_1$ from the pool of heads under consideration and repeat the same procedure:

$$h^*_2 = \operatorname*{argmax}_{h \in \mathcal{H} \setminus \{h^*_1\}} r_h \tag{10}$$

$$\vdots$$

$$h^*_K = \operatorname*{argmax}_{h \in \mathcal{H} \setminus \{h^*_1, \ldots, h^*_{K-1}\}} r_h \tag{11}$$

The probability of sampling these heads in this order is given by the following expression:

$$p(h^*_1, \ldots, h^*_K) = \frac{\iota_{h^*_1}}{Z} \cdots \frac{\iota_{h^*_K}}{Z - \sum_{k=1}^{K-1} \iota_{h^*_k}} \tag{12}$$

Thus, the probability of a set $\mathcal{J}$ is given by

$$p(\mathcal{J} = \{h^*_1, \ldots, h^*_K\}) = \sum_{\pi \in S_K} p(h^*_{\pi_1}, \ldots, h^*_{\pi_K}) \tag{13}$$
where $S_K$ is the set of all permutations of K items. This is hard to compute, as it involves a sum over all K! permutations. For a detailed discussion on computing (13), we refer the reader to Vieira (2021a) and Vieira (2021b). Ultimately, however, computing the exact probability of a subset of heads $\mathcal{J}$ is unnecessary for this approach.

As an aside, we note that this procedure is equivalent to a differentiable version of the classical reservoir sampling algorithm (Vitter, 1985).
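The sketch below, assuming small H and K and using only the standard library, draws an ordered top-K sample as in (9)–(11) and evaluates (12) and (13) by brute force; enumerating all K! orderings is exactly why (13) is impractical beyond toy sizes. The function names are illustrative.

```python
import itertools
import math
import random

def sample_gumbel(rng):
    """Draw n ~ Gumbel(0, 1) by inverse-CDF sampling, keeping u strictly inside (0, 1)."""
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
    return -math.log(-math.log(u))

def gumbel_top_k(iota, k, rng):
    """Eqs. (9)-(11): the k largest perturbed logits, in the order they would be picked."""
    r = [math.log(i) + sample_gumbel(rng) for i in iota]
    # Sorting once is equivalent to K successive argmaxes over a shrinking pool.
    return sorted(range(len(r)), key=r.__getitem__, reverse=True)[:k]

def ordered_prob(iota, seq):
    """Eq. (12): probability of drawing the heads in `seq` in exactly this order."""
    z, p = sum(iota), 1.0
    for h in seq:
        p *= iota[h] / z
        z -= iota[h]  # the chosen head leaves the pool
    return p

def subset_prob(iota, subset):
    """Eq. (13): sum eq. (12) over all K! orderings; only feasible for tiny K."""
    return sum(ordered_prob(iota, perm) for perm in itertools.permutations(subset))
```

For example, with $\iota = (4, 2, 1, 1)$ and K = 2, the set {0, 1} has probability (4/8)(2/4) + (2/8)(4/6) ≈ 0.417, and the empirical frequency of that set under `gumbel_top_k` approaches this value.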
3.2.2 Step 2: Relaxing the argmax
The Gumbel-top-K trick can be relaxed similarly to the top-1 case. This was first shown in detail by Xie and Ermon (2019). Here, we provide a detailed overview of the algorithm by analogy to the top-1 case. Similarly, the output of Gumbel-top-K can be viewed as a K-hot vector, which is the sum of the K one-hot vectors produced in (9)–(11). As before, we begin by relaxing the one-hot vector of the first head:

$$g^{(1)}_h = \frac{\exp(r^{(1)}_h / \tau)}{\sum_{h'=1}^{H} \exp(r^{(1)}_{h'} / \tau)} \tag{14}$$

This is a straightforward analogue of the argmax relaxation discussed in § 3.1.2. Next, we continue
relaxing the successive argmaxes with successive softmaxes (Plötz and Roth, 2018) as follows:

$$g^{(2)}_h = \frac{\exp(r^{(2)}_h / \tau)}{\sum_{h'=1}^{H} \exp(r^{(2)}_{h'} / \tau)} \tag{15}$$

$$\vdots$$

$$g^{(K)}_h = \frac{\exp(r^{(K)}_h / \tau)}{\sum_{h'=1}^{H} \exp(r^{(K)}_{h'} / \tau)} \tag{16}$$

where the $r^{(k)}_h$ are defined recursively:

$$r^{(1)}_h = r_h \tag{17}$$

$$r^{(k+1)}_h = r^{(k)}_h + \log\left(1 - g^{(k)}_h\right) \tag{18}$$
Xie and Ermon (2019) argue that the above recursion corresponds to a reasonable relaxation of the Gumbel-top-K trick presented in § 3.2.1. To understand the motivation behind the recursion in (17)–(18), note that if $g^{(k)}_h = 1$, which would happen if the head had been sampled (i.e., with no relaxation), then that head would not be selected again, as we would have $r^{(k+1)}_h = -\infty$. As the scheme is a relaxation of hard sampling, we will not have $g^{(k)}_h = 1$ as long as $r^{(k)}_h$ is finite and τ > 0. Thus, the procedure corresponds to something akin to soft sampling.
Finally, we sum over all the relaxed one-hot vectors $g^{(k)}_h$ in (14)–(16) to arrive at our softened K-hot gate:

$$g_h = \sum_{k=1}^{K} g^{(k)}_h \tag{19}$$

It is (19) that we finally plug into the gated attention mechanism presented in (4).
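A minimal sketch of the relaxed top-K gate in (14)–(19), written in plain Python with an illustrative function name. The small constant `eps` inside the logarithm is our own numerical safeguard, not part of (18), for the case where $1 - g^{(k)}_h$ underflows to 0 in floating point. With k = 1 the function reduces to the tempered Gumbel-softmax in (8), and as τ → 0 the output approaches the hard K-hot vector of § 3.2.1.

```python
import math

def softmax(xs):
    """Numerically stable softmax."""
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def relaxed_top_k(r, k, tau, eps=1e-20):
    """Soft K-hot gate of eqs. (14)-(19), computed from the perturbed logits r_h."""
    r_k = list(r)                      # eq. (17): r^(1) = r
    gates = [0.0] * len(r)
    for _ in range(k):
        g_k = softmax([x / tau for x in r_k])          # eqs. (14)-(16)
        gates = [g + gk for g, gk in zip(gates, g_k)]  # eq. (19), accumulated
        # eq. (18); eps only guards against log(0) once a gate saturates.
        r_k = [x + math.log(max(1.0 - gk, eps)) for x, gk in zip(r_k, g_k)]
    return gates
```

Because each $g^{(k)}$ is a softmax, the entries of the final gate always sum to K, whatever the temperature; only how sharply that mass concentrates on K heads depends on τ.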
3.3 Training the Subset Pruner
The differentiable subset pruning approach can be applied in either a pipelined or a joint pruning setting. (Please refer back to the last paragraph of § 1 for a discussion of the two different settings.) Our approach is parameterized identically in both settings, however. Specifically, we define the head importance score as follows:

$$\iota_h = \exp(w_h) \tag{20}$$

where $w_h$ is the h-th component of a vector of real-valued head weights $w \in \mathbb{R}^H$. Since $\log(\iota_h) = w_h$, the perturbed logits of § 3.1.1 are simply $r_h = w_h + n_h$. In our setting, the distinction between pipelined pruning and joint pruning is relatively trivial. In the pipelined setting, we learn the head importance weights w for a model that has already been trained on the task and leave the model parameters untouched. In the joint setting, on the other hand, we simultaneously learn the head importance weights and the model parameters. In this regard, our differentiable subset pruner much more closely resembles Voita et al.’s (2019) method in that we learn head-specific importance weights, whereas Michel et al.’s (2019) method makes use of an unlearned gradient-based importance measure. In contrast to Voita et al., however, our differentiable subset pruner is guaranteed to return a specific, pre-specified number of heads.
4 Experiments
4.1 Model and Data
We investigate two Transformer-based models in
the empirical portion of the paper.
BERT. BERT (Bidirectional Encoder Representations from Transformers; Devlin et al., 2019) is essentially a Transformer encoder. Since there is no decoder, BERT only has self-attention. We focus on the base-uncased model with 12 layers and 12 heads in each layer (144 heads in total). We use the Hugging Face implementation (Wolf et al., 2020). The model is pre-trained on large text corpora using masked language modeling (MLM) and next sentence prediction (NSP).

We fine-tune BERT on the Multi-Genre Natural Language Inference (MNLI; Williams et al., 2018) corpus. The hyper-parameters are tuned on the “matched” validation set, and accuracy is reported on the “mismatched” validation set.
Enc–Dec. We implement a Transformer-based encoder–decoder model with 6 encoder layers, 6 decoder layers, and 6 heads in each layer (72 heads in total). The model has three types of attention heads: encoder self-attention, decoder self-attention, and encoder–decoder cross-attention. We use the fairseq toolkit (Ott et al., 2019) for our implementation. We train the model on the International Workshop on Spoken Language Translation (IWSLT2014; Cettolo et al., 2014) German-to-English dataset. The hyper-parameters are tuned on the validation set, and 4-gram BLEU scores computed with multi-bleu.perl (Koehn et al., 2007) are reported on the held-out
test set. We use beam search with a beam size set
to 5 for decoding.
4.2 Baselines
We compare our approach to pruners in both the
pipelined and the joint paradigms. We refer to
the pipelined version of our differentiable subset
pruning as pipelined DSP and to the joint version
as joint DSP. Our specific points of comparison
are listed below.
4.2.1 Michel et al.
Michel et al. follow the pipelined pruning paradigm. Concretely, given a dataset $\mathcal{D} = \{(y_m, x_m)\}_{m=1}^{M}$, the importance of a head is estimated with a gradient-based proxy score (Molchanov et al., 2017b):

$$\iota_h = \left| \frac{1}{M} \sum_{m=1}^{M} \frac{\partial \mathcal{L}(y_m, x_m)}{\partial g_h} \right| \geq 0 \tag{21}$$

where $\mathcal{L}$ is the task-specific loss function. Then, all the heads in the model are sorted accordingly and removed one by one in a greedy fashion. The importance scores are re-computed every time a certain number of heads have been removed.
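For comparison, here is a sketch of this baseline's scoring and greedy loop as we paraphrase it, not Michel et al.'s code. It assumes the per-example gradients $\partial\mathcal{L}(y_m, x_m)/\partial g_h$ are already available as a nested list; `score_fn` is a hypothetical callback that would recompute (21) for the current, partially pruned model, and `step` stands in for the "certain number of heads" removed between re-scorings.

```python
def michel_importance(per_example_grads):
    """Eq. (21): |(1/M) * sum_m dL(y_m, x_m)/dg_h| for every head h.

    per_example_grads[m][h] is assumed to hold dL(y_m, x_m)/dg_h at the current gates.
    """
    M = len(per_example_grads)
    H = len(per_example_grads[0])
    return [abs(sum(row[h] for row in per_example_grads) / M) for h in range(H)]

def greedy_prune(score_fn, num_heads, target_k, step):
    """Drop the `step` lowest-scoring surviving heads, rescore, and repeat until target_k remain.

    score_fn(alive) -> {head: score} is a hypothetical callback, e.g. eq. (21) on the pruned model.
    """
    alive = set(range(num_heads))
    while len(alive) > target_k:
        scores = score_fn(alive)
        n_drop = min(step, len(alive) - target_k)
        for h in sorted(alive, key=scores.__getitem__)[:n_drop]:
            alive.remove(h)
    return alive
```

Because heads are removed greedily and never reconsidered, an early mistake cannot be undone, which is one plausible reason a learned importance score can do better.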
4.2.2 Voita et al.
In the fashion of joint pruning, Voita et al. apply a stochastic approximation to L0 regularization (Louizos et al., 2018) to the gates to encourage the model to prune less important heads. The gate variables are sampled independently from a Hard Concrete distribution (Louizos et al., 2018), a stretched and clipped relaxation of a binary variable, parameterized by $\phi_h$. The L0 norm is relaxed into the sum of the probability mass of each gate being non-zero:

$$\mathcal{L}_C(\phi) = \sum_{h=1}^{H} \left(1 - P(g_h = 0 \mid \phi_h)\right) \tag{22}$$

which is then added to the task-specific loss $\mathcal{L}$:

$$R(\theta, \phi) = \mathcal{L}(\theta, \phi) + \lambda \mathcal{L}_C(\phi) \tag{23}$$

where θ are the parameters of the original model, and λ is the weighting coefficient for the regularization, which can be used to indirectly control the number of heads to be kept.
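The penalty in (22) has a closed form. The sketch below assumes, following Louizos et al. (2018), that $\phi_h$ is the location parameter $\log\alpha_h$ of a Hard Concrete gate with stretch interval (γ, ζ) = (−0.1, 1.1) and temperature β = 2/3; these constants are assumptions about the configuration, not values reported in this paper. A gate is exactly zero when the stretched sample falls below 0, which gives $1 - P(g_h = 0 \mid \phi_h) = \sigma(\log\alpha_h - \beta \log(-\gamma/\zeta))$.

```python
import math
import random

# Hard Concrete constants: assumed defaults in the style of Louizos et al. (2018).
BETA, GAMMA, ZETA = 2.0 / 3.0, -0.1, 1.1

def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def sample_gate(log_alpha, rng):
    """One Hard Concrete gate: stretch a Concrete sample to (GAMMA, ZETA), then clip to [0, 1]."""
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
    s = sigmoid((math.log(u) - math.log(1.0 - u) + log_alpha) / BETA)
    return min(1.0, max(0.0, s * (ZETA - GAMMA) + GAMMA))

def prob_nonzero(log_alpha):
    """Closed form of 1 - P(g_h = 0 | phi_h), taking phi_h to be log_alpha."""
    return sigmoid(log_alpha - BETA * math.log(-GAMMA / ZETA))

def l0_penalty(phis):
    """Eq. (22): expected number of heads whose gate is non-zero."""
    return sum(prob_nonzero(p) for p in phis)

def regularized_loss(task_loss, phis, lam):
    """Eq. (23): task loss plus lambda times the relaxed L0 penalty."""
    return task_loss + lam * l0_penalty(phis)
```

Since λ only scales an expected count, nothing in (23) pins the number of surviving heads to a particular value; that is the indirect control discussed in § 5.2.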
4.2.3 Straight-Through Estimator (STE)
In this baseline, the Gumbel soft top-K in joint DSP is replaced with a hard top-K; during back-propagation, the hard top-K function is treated as if it were the identity function. This is known as the straight-through estimator (Bengio et al., 2013).
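Without an autograd framework, the straight-through estimator can be spelled out as a separate forward and backward rule, as in the minimal sketch below; the function names are illustrative. The forward pass returns an exact K-hot mask over the perturbed logits, and the backward pass hands the incoming gradient through unchanged, as if top-K were the identity. Because the perturbed logits are the log importance scores plus noise, that gradient then reaches the head importance parameters directly.

```python
def hard_top_k(r, k):
    """Forward pass: an exact K-hot mask selecting the k largest perturbed logits."""
    top = sorted(range(len(r)), key=r.__getitem__, reverse=True)[:k]
    mask = [0.0] * len(r)
    for h in top:
        mask[h] = 1.0
    return mask

def hard_top_k_backward(grad_mask):
    """Straight-through backward pass: treat top-K as the identity, so dL/dr_h := dL/dg_h."""
    return list(grad_mask)
```

The forward pass never produces fractional gates, so the sparsity constraint holds exactly at every step; the price is a biased gradient.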
4.2.4 Unpruned Model
The model is trained or fine-tuned without any sparsity-enforcing regularizer, and no post-hoc pruning procedure is performed. We take the performance of this model to be an upper bound on the performance of any pruning technique.
4.3 Experimental Setup
Pipelined Pruning. For the two pipelined pruning schemes, the model is trained or fine-tuned on the target task (3 epochs for BERT and 60 epochs for Enc–Dec) before being pruned. We learn the head importance weights for pipelined DSP for one additional epoch in order to have an apples-to-apples comparison with Michel et al. in terms of compute (number of gradients computed).
Joint Pruning. The model is trained or fine-tuned for the same number of epochs as in pipelined pruning while sparsity-enforcing regularization is applied. We found it hard to tune the weighting coefficient λ for Voita et al. to reach the desired sparsity (see § 5.2 and Figure 3). For ease of comparison with other approaches, we adjust the number of unpruned heads to the targeted number by re-including the discarded heads with the highest gate values, or by excluding the kept heads with the smallest gate values. We make sure the adjustments are as small as possible.
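The gate-count adjustment described above can be sketched as follows, assuming access to each head's final gate value and the set of heads the regularizer kept; `adjust_to_target` is a hypothetical helper name. Only as many heads as needed are moved in or out, which keeps the adjustment minimal.

```python
def adjust_to_target(gate_values, kept, k):
    """Return a set of exactly k heads that differs from `kept` as little as possible."""
    kept = set(kept)
    if len(kept) < k:
        # Re-include the discarded heads with the highest gate values.
        discarded = [h for h in range(len(gate_values)) if h not in kept]
        discarded.sort(key=gate_values.__getitem__, reverse=True)
        kept.update(discarded[: k - len(kept)])
    elif len(kept) > k:
        # Exclude the kept heads with the smallest gate values.
        ranked = sorted(kept, key=gate_values.__getitem__)
        kept.difference_update(ranked[: len(kept) - k])
    return kept
```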
Annealing Schedule. In our experiments, we choose a simple annealing schedule for DSP where the temperature τ cools down on a log-linear scale within a predefined number of steps $N_{\text{cooldown}}$ from an initial temperature $\tau_{\text{ini}}$ and then stays at the final temperature $\tau_{\text{end}}$ for the rest of the training steps:

$$\log \tau = \log \tau_{\text{ini}} - \min\left(\frac{n}{N_{\text{cooldown}}}, 1\right) \cdot \left(\log \tau_{\text{ini}} - \log \tau_{\text{end}}\right) \tag{24}$$

where n is the number of training steps that have been run. We report the set of hyperparameters used in our experiments in Appendix A.
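Equation (24) translates directly into code. This is a minimal sketch with an illustrative function name; it takes the temperatures and cool-down length as arguments rather than assuming any particular values.

```python
import math

def temperature(n, tau_ini, tau_end, n_cooldown):
    """Eq. (24): log-linear cooling from tau_ini to tau_end over n_cooldown steps, then constant."""
    progress = min(n / n_cooldown, 1.0)
    log_tau = math.log(tau_ini) - progress * (math.log(tau_ini) - math.log(tau_end))
    return math.exp(log_tau)
```

Because (24) interpolates linearly in log space, the temperature halfway through the cool-down is the geometric mean $\sqrt{\tau_{\text{ini}}\,\tau_{\text{end}}}$ rather than the arithmetic mean.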
4.4 Results
The test performance under various sparsity levels obtained by multiple pruning methods is presented in Figure 2a, Figure 2b, and Appendix C. We also zoom in on the results when more than two-thirds of the heads are pruned in Figure 2c and Figure 2d, where the differences between the various methods are most evident.

Figure 2: A comparison of various pruning methods.
5 Discussion
5.1 Pipelined Pruning
We first compare the two pipelined pruning methods: Michel et al. (2019) and pipelined DSP. As shown in Figure 2, pipelined DSP outperforms Michel et al. by a large margin. For example, on the MNLI task, when there are 24 heads left in the model, pipelined DSP keeps an accuracy above 70%, but Michel et al. drops below 50%. On the IWSLT dataset, when only 24 heads are left unpruned, the Enc–Dec pruned with Michel et al. cannot produce meaningful outputs (≈ 0 BLEU), while pipelined DSP achieves more than 20 BLEU. The results indicate that the learned head importance scores are more useful for pruning than those computed with gradient-based measures.
5.2 Joint Pruning
We then compare the three joint pruning methods: Voita et al. (2019), STE, and joint DSP. Impressively, joint DSP is able to prune up to 91.6% (12 heads left) and 94.4% (4 heads left) of heads in BERT and the Enc–Dec, respectively, without causing much degradation in test performance (a 5.5% drop in accuracy for MNLI and a 4.22 drop in BLEU for IWSLT). Voita et al. and STE are neck and neck with joint DSP when the model is lightly pruned, but joint DSP gains the upper hand when fewer than 1/6 of the heads are left unpruned.

Figure 3: Number of unpruned heads as a function of the L0 regularization coefficient λ for Voita et al.
Figure 4: Inference speedup (%) and model size shrinkage (%) of the pruned BERT model on the MNLI-mismatched validation set as a function of remaining heads.
In addition, with Voita et al.’s method, it is much harder to enforce a hard constraint on the number of unpruned heads. This difficulty is intrinsic to their method, as it relies on the regularization coefficient λ to indirectly control the sparsity. In practice, our experiments indicate that λ is hard to tune and that there are certain levels of sparsity that cannot be reached. The difficulty in tuning λ is shown in Figure 3: the number of unpruned heads does not decrease monotonically as λ increases; on the contrary, it often fluctuates. There also appears to be an upper bound (117) on the number of heads that can be kept, no matter how small λ is. More importantly, a small increase in λ can sometimes drastically reduce the number of heads. For instance, when λ is increased from 0.0009 to 0.0014, the number of heads drops sharply from 30 to 11. Therefore, we conclude that Voita et al.’s method is inadequate if the user requires a pre-specified number of Transformer heads. In contrast, our proposal DSP (as well as STE) enables us to directly specify the number of heads we want to keep in accordance with our computation budget.
5.3 Pipelined Pruning vs Joint Pruning
Lastly, we offer a philosophical comparison of the two pruning paradigms. It is clear from Figure 2 that the joint pruning methods are superior to the pipelined pruning methods for both tasks, as models sparsified with the joint pruning schemes (joint DSP, STE, and Voita et al.) perform better than those pruned with pipelined schemes (pipelined DSP and Michel et al.) at almost every sparsity level. This suggests that joint training is more effective in finding sparse subnetworks than pipelined pruning. Moreover, joint pruning is also more computationally efficient. In addition to the same number of training/fine-tuning epochs required by both paradigms, pipelined pruning requires us to learn or estimate head importance scores for one extra epoch. Even though joint pruning methods train H additional parameters during training/fine-tuning, H is typically orders of magnitude smaller than the total number of model parameters (e.g., 144 head weights versus roughly 110 million parameters in BERT-base), so the additional computational overhead is negligible.
5.4 Inference Efficiency
In this section, we obtain the pruned model by actually removing the heads whose mask value is 0. Empirically, we observe substantial wall-clock improvements in our pruned models compared to unpruned models. In practice, we found that inference efficiency improves monotonically as the number of unpruned heads decreases and is not significantly impacted by the distribution of heads across layers. Taking BERT on the MNLI-mismatched validation set (batch size of 8) as an example, we randomly sample 10 head masks for each sparsity level, measure their inference speedup and model size shrinkage compared to the unpruned model, and report the average in Figure 4. In general, head pruning does lead to a faster and smaller model, and the more we prune, the faster and smaller the model becomes.

Figure 5: Inference speedup (%) and model size shrinkage (%) of the various pruned BERT models vs. accuracy (%) on the MNLI-mismatched validation set.
Figure 6: Distribution of unpruned heads across layers. Darkness of the color increases monotonically with the number of heads.
A comparison of various pruning schemes is displayed in Figure 5. If we set a threshold for accuracy (e.g., 80%), joint DSP returns a model with an ≈ 33% speedup in execution time and an ≈ 24% decrease in model size.
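As a rough back-of-the-envelope check on why removing heads shrinks the model, the sketch below counts the attention parameters that disappear per removed head. It is not the measurement procedure behind Figures 4 and 5. It assumes BERT-base dimensions (hidden size 768, 12 heads of size 64 per layer) and a total of roughly 110M parameters; the shared output-projection bias and all non-attention parameters stay. The function names are illustrative.

```python
HIDDEN = 768    # BERT-base hidden size (assumed)
HEAD_DIM = 64   # per-head size: 768 / 12 heads (assumed)

def params_per_head(hidden=HIDDEN, head_dim=HEAD_DIM):
    """Parameters that disappear when one self-attention head is physically removed."""
    qkv = 3 * (hidden * head_dim + head_dim)  # this head's slices of the Q, K, V weights and biases
    out = head_dim * hidden                   # this head's rows of the output projection W_o
    return qkv + out

def estimated_shrinkage(kept_heads, total_heads=144, total_params=110_000_000):
    """Fraction of all parameters removed; total_params is a rough, assumed BERT-base size."""
    return (total_heads - kept_heads) * params_per_head() / total_params

print(estimated_shrinkage(12))  # ≈ 0.236 under the assumptions above
```

Wall-clock speedup cannot be estimated this way, since it also depends on sequence length, batch size, and how much of the runtime the feed-forward layers account for.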
5.5 Distribution of Heads
We visualize the distribution of unpruned heads across different layers in Figure 6. For BERT (Figure 6a), we observe that the top layers (10–12) are the first to be pruned and the heads in the middle layers (3–7) are mostly retained. This observation is in conformity with Prasanna et al. (2020) and Sajjad et al. (2021). Budhraja et al. (2020) also highlight the importance of middle layers but find no preference between top and bottom layers. For Enc–Dec (Figure 6b), we find that many more encoder–decoder cross-attention heads are retained compared to the other two types of attention (encoder and decoder self-attention). The encoder self-attention heads are completely pruned away when fewer than 16 heads are left, which again conforms with the observations of Michel et al. (2019) and Voita et al. (2019).
Figure 7: Training dynamics of joint DSP on BERT (K = 12). The lower x-axis shows the number of training steps, and the upper x-axis shows the corresponding temperature on a logarithmic scale. The left y-axis (orange) shows test accuracy on the MNLI-mismatched validation set. The right y-axis (purple) shows the percentage of heads selected at the current step that are eventually kept.
Methods | Low Computation Overhead | Sparsity Controllability | Test Performance
Michel et al. | ✗ | ✓ | ✗
Pipelined DSP (this paper) | ✗ | ✓ | ✗
Voita et al. | ✓ | ✗ | ✓
STE (this paper) | ✓ | ✓ | ✓
Joint DSP (this paper) | ✓ | ✓ | ✓
Table 1: Qualitative comparison of different pruning methods.
5.6 Analysis of Training Dynamics
To better understand our joint DSP approach, we
inspect its behavior during training. In Figure 7a,
we plot the intermediate accuracy of BERT during
training when joint DSP (K = 12) is applied (in
orange). We also compute the percentage of heads
selected at the current step that are eventually kept
at the end of training (in purple). We observe that
the selected subset of heads is no longer updated
after 14000 training steps (the purple line stays at
100%). Therefore, the joint pruning process may
be viewed as having two distinct phases: (i) head
selection and (ii) fine-tuning. Interestingly, this
superficially resembles pipelined pruning in reverse.
During head selection, the subset of heads to be
kept is determined and the model is adapted to the
specified level of sparsity. During fine-tuning, the
selected subnetwork is fine-tuned so that test
accuracy improves steadily. Our experiments
indicate that annealing is essential for training a
high-performance pruner: it allows the model to
gradually settle on one particular subset of heads,
whereas without annealing the pruner never
converges to a fixed set and thus never enters the
fine-tuning phase. See Figure 7b for a visualization.6
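To make the role of the temperature concrete, here is a minimal NumPy sketch of annealed subset selection. It assumes a Gumbel-perturbed, iterated-softmax relaxation of top-K in the style of Xie and Ermon (2019), which this paper cites. It is an illustration, not the authors' implementation, and the function names (`relaxed_topk`, `sample_head_mask`, `fraction_eventually_kept`) and the random importance logits are hypothetical. At a high temperature τ the soft mask is spread over all heads, and as τ approaches 0 it collapses to an exact K-hot vector. `fraction_eventually_kept` computes the quantity plotted by the purple line in Figure 7.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    z = x - x.max()
    e = np.exp(z)
    return e / e.sum()


def relaxed_topk(logits, k, tau, eps=1e-20):
    """Soft K-hot vector via k rounds of softmax (Xie and Ermon, 2019 style).

    After each round the softly selected entries are down-weighted by adding
    log(1 - a), so later rounds favor entries not yet chosen. The output always
    sums to k; as tau -> 0 it approaches the exact top-k indicator of `logits`.
    """
    alpha = np.asarray(logits, dtype=float).copy()
    khot = np.zeros_like(alpha)
    for _ in range(k):
        a = softmax(alpha / tau)
        khot += a
        alpha = alpha + np.log(np.clip(1.0 - a, eps, 1.0))
    return khot


def sample_head_mask(importance, k, tau, rng):
    """Gumbel-perturb per-head importance logits, then relax the top-k selection.

    Returns the soft mask and the indices a hard top-k would keep.
    """
    perturbed = importance + rng.gumbel(size=importance.shape)
    hard_idx = np.argsort(-perturbed)[:k]
    return relaxed_topk(perturbed, k, tau), hard_idx


def fraction_eventually_kept(selected_now, selected_final):
    """Share of heads selected at the current step that are also in the final subset."""
    return len(set(selected_now) & set(selected_final)) / len(selected_now)


rng = np.random.default_rng(0)
importance = rng.normal(size=144)  # hypothetical logits for BERT-base's 12 x 12 heads
for tau in (1e3, 1.0, 1e-3):
    soft, hard_idx = sample_head_mask(importance, k=12, tau=tau, rng=rng)
    print(f"tau={tau:g}  mask sum={soft.sum():.3f}  largest entry={soft.max():.3f}")
```

Because every round of `relaxed_topk` adds a distribution that sums to one, the soft mask always sums to K. The sparsity constraint therefore holds at every temperature, and annealing only changes how sharp the mask is.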
5.7 Summary
The five pruning methods discussed in this paper
are summarized in Table 1. Joint DSP maintains
the highest test performance while consuming
computational resources similar to those of Voita
et al. and offering fine-grained control over the
number of unpruned heads, as Michel et al. do. It is
worth noting that STE shares with joint DSP the
benefits of low computational overhead and exact
sparsity control, despite being slightly inferior in
performance. It also has fewer hyperparameters
to tune and hence is easier to implement. There-
fore, we believe STE may be preferable when test
performance is less critical.
6We analyzed other sparsity levels as well and observed
similar behavior. Two examples are shown in Appendix B.
6 Related Work
Unstructured Pruning. Neural network prun-
ing has been studied for decades. Early work in-
cludes optimal brain damage (LeCun et al., 1990)
and optimal brain surgeon (Hassibi et al., 1994),
which approximate the loss function of a trained
model with a second-order Taylor expansion and
remove parameters from the network while
minimizing the impact on the loss. Recent years
have seen a resurgence of this approach (Molchanov
et al., 2017b; Theis et al., 2018; Michel et al., 2019).
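The saliency used by optimal brain damage can be sketched in a few lines. This NumPy example assumes a diagonal Hessian approximation and a trained model whose gradient is approximately zero, so removing weight w_i is estimated to increase the loss by ½·H_ii·w_i². The Hessian values and function names are made up for illustration, and optimal brain surgeon, which uses the full inverse Hessian and also adjusts the remaining weights, is not shown.

```python
import numpy as np


def obd_saliency(weights, hessian_diag):
    """Optimal brain damage saliency for each parameter.

    Around a trained minimum the first-order Taylor term is taken to be zero and
    off-diagonal Hessian entries are ignored, so deleting w_i is estimated to
    raise the loss by 0.5 * H_ii * w_i**2.
    """
    return 0.5 * hessian_diag * weights ** 2


def obd_prune(weights, hessian_diag, n_prune):
    """Zero the n_prune parameters with the smallest estimated loss increase."""
    saliency = obd_saliency(weights, hessian_diag)
    pruned = weights.copy()
    pruned[np.argsort(saliency)[:n_prune]] = 0.0
    return pruned


w = np.array([0.5, -2.0, 0.3])
h = np.array([4.0, 0.01, 10.0])  # illustrative Hessian diagonal
print(obd_saliency(w, h))        # approx. [0.5, 0.02, 0.45]
print(obd_prune(w, h, 1))        # approx. [0.5, 0.0, 0.3]
```

In this toy case OBD removes −2.0, the weight with the largest magnitude, because its curvature is tiny; magnitude pruning would instead remove 0.3.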
More recently, magnitude pruning that discards
parameters with small absolute values has gained
much popularity (Han et al., 2015, 2016; Guo et al.,
2016; Zhu and Gupta, 2018). Gordon et al. (2020)
apply magnitude pruning to BERT and show that
the model has similar prunability and transferability
whether it is pruned after pre-training or after
fine-tuning. Related to magnitude-based pruning is
movement pruning, introduced by Sanh et al.
(2020), which uses changes in weights rather than
their magnitudes as the pruning criterion.
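Both criteria can be sketched with a shared helper that zeroes the lowest-scoring entries. In this NumPy sketch, magnitude pruning scores each weight by |w|, and the movement score is a simplified first-order version in the spirit of Sanh et al. (2020) that accumulates −w·∂L/∂w over fine-tuning steps. Their full method includes further details not reproduced here, and all function names are hypothetical.

```python
import numpy as np


def prune_by_score(weights, scores, sparsity):
    """Zero the `sparsity` fraction of entries with the lowest scores.

    Returns the pruned weights and the boolean keep-mask. The matrix keeps its
    shape: this is unstructured pruning, so the zeros are scattered.
    """
    n_prune = int(round(sparsity * weights.size))
    order = np.argsort(scores, axis=None)  # flattened, ascending
    keep = np.ones(weights.size, dtype=bool)
    keep[order[:n_prune]] = False
    keep = keep.reshape(weights.shape)
    return np.where(keep, weights, 0.0), keep


def magnitude_scores(weights):
    """Magnitude pruning: importance is |w|."""
    return np.abs(weights)


def movement_scores(weight_history, grad_history):
    """Simplified movement score: -sum_t w_t * dL/dw_t.

    Under gradient descent the update is -lr * grad, so a positive score means
    the weight has been moving away from zero during fine-tuning.
    """
    return -sum(w * g for w, g in zip(weight_history, grad_history))


w = np.array([[0.1, -0.5], [2.0, -0.05]])
pruned, keep = prune_by_score(w, magnitude_scores(w), sparsity=0.5)

w_hist = [np.array([1.0, 1.0])]
g_hist = [np.array([-0.1, 0.1])]  # first weight is growing, second is shrinking
s = movement_scores(w_hist, g_hist)
pruned2, _ = prune_by_score(w_hist[-1], s, sparsity=0.5)
```

The two weights in the movement example have equal magnitude, so magnitude pruning could not separate them, whereas the movement score keeps the one that fine-tuning pushes away from zero. In both cases the pruned matrix keeps its original shape.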
Structured Pruning. Unlike the unstructured
pruning methods mentioned above, which prune
individual parameters, structured pruning methods
prune at a higher level, such as convolutional
channels, attention heads, or even entire layers.
Structured pruning almost always reduces model
size and inference cost, whereas unstructured
pruning typically produces sparse matrices that
cannot be exploited without dedicated hardware
or libraries (Han et al., 2016). Structured pruning
was previously applied mainly to convolutional
neural networks (Wen et al., 2016; Li et al., 2017;
Luo et al., 2017; He et al., 2017; Liu et al., 2017;
Huang and Wang, 2018), but it has recently been
applied to NLP, in the form of layer pruning (Fan
et al., 2020; Sajjad et al., 2021) and head pruning
(Michel et al., 2019; Voita et al., 2019; McCarley
et al., 2021) of Transformer-based models. Apart
from compression and speedup, head pruning is
also helpful for model analysis; Voita et al. (2019)
find that the heads that survive pruning play
consistent and linguistically interpretable roles.
Prasanna et al. (2020) found that the heads that are
pruned last tend to be in the earlier and middle layers.
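The difference between structured and unstructured pruning is easy to see for attention heads. This NumPy sketch assumes standard scaled dot-product multi-head self-attention (Vaswani et al., 2017) without biases, with each head stored as a contiguous column block of the projection matrices; the function names and random weights are hypothetical. Masking a head's output to zero gives the same result as slicing that head's columns out of W_Q, W_K, W_V and its rows out of W_O, and the sliced matrices are simply smaller dense matrices.

```python
import numpy as np


def multi_head_attention(x, wq, wk, wv, wo, n_heads, head_mask=None):
    """Multi-head self-attention on one sequence x of shape (T, d_model), no biases.

    head_mask (optional, shape (n_heads,)) multiplies each head's output, so a
    0 entry removes that head's contribution.
    """
    T = x.shape[0]
    d_head = wq.shape[1] // n_heads

    def split(m):  # (T, H * d_head) -> (H, T, d_head)
        return m.reshape(T, n_heads, d_head).transpose(1, 0, 2)

    q, k, v = split(x @ wq), split(x @ wk), split(x @ wv)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_head)  # (H, T, T)
    scores = scores - scores.max(axis=-1, keepdims=True)
    attn = np.exp(scores)
    attn = attn / attn.sum(axis=-1, keepdims=True)
    heads = attn @ v                                     # (H, T, d_head)
    if head_mask is not None:
        heads = heads * head_mask[:, None, None]
    concat = heads.transpose(1, 0, 2).reshape(T, n_heads * d_head)
    return concat @ wo


def prune_heads(wq, wk, wv, wo, n_heads, keep):
    """Structured pruning: slice out the kept heads, giving smaller dense matrices."""
    d_head = wq.shape[1] // n_heads
    cols = np.concatenate([np.arange(h * d_head, (h + 1) * d_head) for h in keep])
    return wq[:, cols], wk[:, cols], wv[:, cols], wo[cols, :]


rng = np.random.default_rng(0)
T, d_model, H = 5, 8, 4
x = rng.normal(size=(T, d_model))
wq, wk, wv = (rng.normal(size=(d_model, d_model)) for _ in range(3))
wo = rng.normal(size=(d_model, d_model))
keep = [0, 2]

masked = multi_head_attention(x, wq, wk, wv, wo, H, head_mask=np.array([1.0, 0.0, 1.0, 0.0]))
small = prune_heads(wq, wk, wv, wo, H, keep)
pruned = multi_head_attention(x, *small, n_heads=len(keep))
```

Here the pruned projections are 8 × 4 (and 4 × 8 for W_O) instead of 8 × 8, so the saving is realized with ordinary dense matrix products.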
Dropout for Pruning. A variety of regularizers
have been used to sparsify neural networks. For
example, Han et al. (2015) apply L1 regularization,
and Louizos et al. (2018) apply L0 regularization.
Dropout, one such regularizer, has also been shown
to be effective at making a model robust to pruning.
That dropout encourages sparsity was already
observed when dropout was proposed (Srivastava
et al., 2014). More recently, the hypothesis that
models trained with dropout tend to be more robust
to post-hoc pruning has also been explored.
LayerDrop (Fan et al., 2020) randomly drops entire
Transformer layers with a fixed dropout rate during
training and simply keeps every other layer during
inference. Targeted Dropout (Gomez et al., 2019)
ranks units by magnitude, applies dropout only to
those with small magnitudes, and performs
magnitude pruning afterwards. Kingma et al. (2015)
introduce variational dropout, which allows the
dropout rate to be learned, and Molchanov et al.
(2017a) extend it to pruning by learning a separate
dropout rate for each weight and keeping only the
weights with low dropout rates at test time. Our
approach is in the same vein but differs in that we
learn importance variables rather than dropout
rates, and the number of heads to be dropped is
specified explicitly, which gives us control over
sparsity.
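A weight-level version of targeted dropout can be sketched as follows. The NumPy code assumes that the γ fraction of weights with the smallest magnitude are dropout candidates and that each candidate is dropped with probability α. It omits the usual 1/(1 − α) rescaling for simplicity, the function name is hypothetical, and the unit-level variant described by Gomez et al. (2019) is not shown.

```python
import numpy as np


def targeted_dropout(weights, gamma, alpha, rng):
    """Weight-level targeted dropout (after Gomez et al., 2019), training time only.

    The gamma fraction of weights with the smallest |w| are dropout candidates;
    each candidate is zeroed independently with probability alpha. Weights
    outside the candidate set are never dropped. No rescaling is applied.
    """
    n_candidates = int(round(gamma * weights.size))
    order = np.argsort(np.abs(weights), axis=None)
    candidate = np.zeros(weights.size, dtype=bool)
    candidate[order[:n_candidates]] = True
    candidate = candidate.reshape(weights.shape)
    drop = candidate & (rng.random(weights.shape) < alpha)
    return np.where(drop, 0.0, weights)


rng = np.random.default_rng(0)
w = np.array([0.05, -1.0, 0.2, 3.0])
noisy = targeted_dropout(w, gamma=0.5, alpha=0.5, rng=rng)  # only 0.05 or 0.2 can be zeroed
```

Because only the low-magnitude candidates are ever dropped during training, the network is encouraged not to rely on them, which is intended to make the subsequent magnitude pruning of exactly those weights cheap in accuracy.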
Lottery Ticket Hypothesis. Frankle and Carbin
(2019) propose the Lottery Ticket Hypothesis: an
over-parameterized model contains subnetworks
(‘‘winning lottery tickets’’) that can be trained in
isolation to reach test performance comparable to
that of the original network in a similar number of
iterations. They show that such tickets can be
discovered through magnitude pruning. Brix et al.
(2020) successfully apply the hypothesis to the
Transformer. Prasanna et al. (2020) and Behnke
and Heafield (2020) demonstrate that head pruning
may also be used to select a winning subnetwork.
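A one-shot version of the ticket-finding procedure can be sketched in NumPy: rank weights by their trained magnitude, keep the largest fraction, and rewind the survivors to their initial values. Retraining is omitted and the function name is hypothetical; Frankle and Carbin (2019) also use an iterative variant that repeats the prune, rewind, and retrain loop several times.

```python
import numpy as np


def extract_ticket(init_weights, trained_weights, sparsity):
    """One-shot lottery-ticket extraction in the spirit of Frankle and Carbin (2019).

    Keep the (1 - sparsity) fraction of weights with the largest trained
    magnitude and rewind them to their initial values. The returned subnetwork
    would then be retrained in isolation (not shown).
    """
    n_prune = int(round(sparsity * trained_weights.size))
    order = np.argsort(np.abs(trained_weights), axis=None)
    mask = np.ones(trained_weights.size, dtype=bool)
    mask[order[:n_prune]] = False
    mask = mask.reshape(trained_weights.shape)
    return np.where(mask, init_weights, 0.0), mask


init = np.array([0.1, 0.2, 0.3, 0.4])
trained = np.array([1.0, -0.01, 0.5, -2.0])
ticket, mask = extract_ticket(init, trained, sparsity=0.5)
```

The mask is chosen from the trained weights, but the values placed in the surviving positions come from the initialization, which is what distinguishes a lottery ticket from ordinary magnitude pruning.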
7 Conclusion
We propose differentiable subset pruning, a novel
method for sparsifying Transformers. The method
allows the user to directly specify the desired
sparsity level, and it achieves a better sparsity–
accuracy trade-off compared to previous work,
leading to a faster and more efficient model after
pruning. It demonstrates improvements over ex-
isting methods for pruning two different models
(BERT and Enc–Dec) on two different tasks
(textual entailment and machine translation), re-
spectively. It can be applied in both pruning para-
digms (pipelined and joint pruning). Although we
study head pruning in this paper, our approach can
be extended to other structured and unstructured
pruning scenarios. In future work, it would be in-
teresting to look into such cases.
Acknowledgments
We would like to thank the action editor Noah
Smith and the anonymous reviewers for their
helpful comments. MS acknowledges funding by
SNF under project #201009.
References
Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua
Bengio. 2015. Neural machine translation by
jointly learning to align and translate. In
3rd International Conference on Learning
Representations.
Maximiliana Behnke and Kenneth Heafield. 2020.
Losing heads in the lottery: Pruning trans-
former attention in neural machine translation.
In Proceedings of
the 2020 Conference on
Empirical Methods in Natural Language Pro-
cessing (EMNLP), pages 2664–2674, Online.
Association for Computational Linguistics.
Yoshua Bengio, Nicholas Léonard, and Aaron
C. Courville. 2013. Estimating or propagating
gradients through stochastic neurons for condi-
tional computation. CoRR, abs/1308.3432v1.
Christopher Brix, Parnia Bahar, and Hermann
Ney. 2020. Successfully applying the stabilized
lottery ticket hypothesis to the transformer ar-
chitecture. In Proceedings of the 58th Annual
Meeting of the Association for Computational
Linguistics, pages 3909–3915, Online. Associ-
ation for Computational Linguistics.
Tom Brown, Benjamin Mann, Nick Ryder,
Melanie Subbiah, Jared D. Kaplan, Prafulla
Dhariwal, Arvind Neelakantan, Pranav Shyam,
Girish Sastry, Amanda Askell, Sandhini
Agarwal, Ariel Herbert-Voss, Gretchen
Krueger, Tom Henighan, Rewon Child, Aditya
Ramesh, Daniel Ziegler, Jeffrey Wu, Clemens
Winter, Chris Hesse, Mark Chen, Eric
Sigler, Mateusz Litwin, Scott Gray, Benjamin
Chess, Jack Clark, Christopher Berner, Sam
McCandlish, Alec Radford, Ilya Sutskever, and
Dario Amodei. 2020. Language models are
few-shot learners. In Advances in Neural In-
formation Processing Systems, volume 33,
pages 1877–1901.
Aakriti Budhraja, Madhura Pande, Preksha Nema,
Pratyush Kumar, and Mitesh M. Khapra. 2020.
On the weak link between importance and prun-
ability of attention heads. In Proceedings of
the 2020 Conference on Empirical Methods
in Natural Language Processing (EMNLP),
pages 3230–3235, Online. Association for
Computational Linguistics.
Mauro Cettolo, Jan Niehues, Sebastian Stüker,
Luisa Bentivogli, and Marcello Federico. 2014.
Report on the 11th IWSLT evaluation cam-
paign. In Proceedings of
the International
Workshop on Spoken Language Translation,
Hanoi, Vietnam, volume 57.
Kevin Clark, Urvashi Khandelwal, Omer Levy,
and Christopher D. Manning. 2019. What does
BERT look at? An analysis of BERT’s atten-
tion. In Proceedings of the 2019 ACL Workshop
BlackboxNLP: Analyzing and Interpreting Neu-
ral Networks for NLP, pages 276–286.
Ryan Cotterell and Jason Eisner. 2017. Probabilis-
tic typology: Deep generative models of vowel
inventories. In Proceedings of the 55th An-
nual Meeting of the Association for Compu-
tational Linguistics (Volume 1: Long Papers),
pages 1182–1192, Vancouver, Canada. Associ-
ation for Computational Linguistics.
Jacob Devlin, Ming-Wei Chang, Kenton Lee, and
Kristina Toutanova. 2019. BERT: Pre-training
of deep bidirectional transformers for language
understanding. In Proceedings of the 2019 Con-
ference of
the North American Chapter of
the Association for Computational Linguistics:
Human Language Technologies, Volume 1
(Long and Short Papers), pages 4171–4186,
Minneapolis, Minnesota. Association for Com-
putational Linguistics.
Allyson Ettinger. 2020. What BERT is not:
Lessons from a new suite of psycholinguistic di-
agnostics for language models. Transactions of
the Association for Computational Linguistics,
8:34–48.
Angela Fan, Edouard Grave, and Armand Joulin.
2020. Reducing transformer depth on demand
with structured dropout. In 8th International
Conference on Learning Representations.
Jonathan Frankle and Michael Carbin. 2019. The
lottery ticket hypothesis: Finding sparse, trainable
neural networks. In 7th International Conference
on Learning Representations.
Yarin Gal and Zoubin Ghahramani. 2016. Dropout
as a Bayesian approximation: Representing model
uncertainty in deep learning. In Proceedings of The
33rd International Conference on Machine Learning,
volume 48 of Proceedings of Machine Learning
Research, pages 1050–1059, New York, New York,
USA. PMLR.
Jennifer Gillenwater, Alex Kulesza, and Ben
Taskar. 2012. Discovering diverse and salient
threads in document collections. In Proceedings of
the 2012 Joint Conference on Empirical Methods in
Natural Language Processing and Computational
Natural Language Learning, pages 710–720, Jeju
Island, Korea. Association for Computational
Linguistics.
Yoav Goldberg. 2019. Assessing BERT’s syntactic
abilities. CoRR, abs/1901.05287v1.
Aidan N. Gomez, Ivan Zhang, Kevin Swersky,
Yarin Gal, and Geoffrey E. Hinton. 2019.
Learning sparse networks using targeted dropout.
CoRR, abs/1905.13678v5.
Mitchell Gordon, Kevin Duh, and Nicholas
Andrews. 2020. Compressing BERT: Studying the
effects of weight pruning on transfer learning. In
Proceedings of the 5th Workshop on Representation
Learning for NLP, pages 143–155, Online.
Association for Computational Linguistics.
Emil Julius Gumbel. 1954. Statistical theory of
extreme values and some practical applications.
Journal of the Royal Aeronautical Society,
58(527):792–793.
Yiwen Guo, Anbang Yao, and Yurong Chen.
2016. Dynamic network surgery for efficient
DNNs. In Advances in Neural Information
Processing Systems, volume 29.
S. Han, X. Liu, H. Mao, J. Pu, A. Pedram, M. A.
Horowitz, and W. J. Dally. 2016. EIE: Efficient
inference engine on compressed deep neural
network. In 2016 ACM/IEEE 43rd Annual
International Symposium on Computer
Architecture (ISCA), pages 243–254.
Song Han, Huizi Mao, and William J. Dally.
2016. Deep compression: Compressing deep
neural network with pruning, trained quantiza-
tion and huffman coding. In 4th International
Conference on Learning Representations.
Song Han, Jeff Pool, John Tran, and William
Dally. 2015. Learning both weights and connec-
tions for efficient neural network. In Advances
in Neural Information Processing Systems,
volume 28.
Babak Hassibi, David Stork, and Gregory Wolff.
1994. Optimal brain surgeon: Extensions and
performance comparisons. In Advances in Neural
Information Processing Systems, volume 6.
Y. He, X. Zhang, and J. Sun. 2017. Channel
pruning for accelerating very deep neural net-
works. In 2017 IEEE International Conference
on Computer Vision (ICCV), pages 1398–1406.
Zehao Huang and Naiyan Wang. 2018. Data-
driven sparse structure selection for deep neu-
ral networks. In Proceedings of the European
Conference on Computer Vision (ECCV).
Eric Jang, Shixiang Gu, and Ben Poole. 2017.
Categorical reparameterization with Gumbel-
softmax. In 5th International Conference on
Learning Representations.
Ganesh Jawahar, Benoît Sagot, and Djamé
Seddah. 2019. What does BERT learn about
the structure of language? In Proceedings of
the 57th Annual Meeting of the Association for
Computational Linguistics, pages 3651–3657,
Florence, Italy. Association for Computational
Linguistics.
Durk P. Kingma, Tim Salimans, and Max Welling.
2015. Variational dropout and the local repa-
rameterization trick. In Advances in Neural In-
formation Processing Systems, volume 28.
Philipp Koehn, Hieu Hoang, Alexandra Birch,
Chris Callison-Burch, Marcello Federico,
Nicola Bertoldi, Brooke Cowan, Wade Shen,
Christine Moran, Richard Zens, Chris Dyer,
Ondřej Bojar, Alexandra Constantin, and Evan
Herbst. 2007. Moses: Open source toolkit for
statistical machine translation. In Proceedings
of the 45th Annual Meeting of the Association
for Computational Linguistics Companion Vol-
ume Proceedings of the Demo and Poster Ses-
sions, pages 177–180, Prague, Czech Republic.
Association for Computational Linguistics.
Wouter Kool, Herke Van Hoof, and Max Welling.
2019. Stochastic beams and where to find them:
The Gumbel-top-k trick for sampling sequences
without replacement. In Proceedings of the 36th
International Conference on Machine Learn-
ing, volume 97 of Proceedings of Machine
Learning Research, pages 3499–3508. PMLR.
Yann LeCun, John Denker, and Sara Solla. 1990.
Optimal brain damage. In Advances in Neural
Information Processing Systems, volume 2.
Hao Li, Asim Kadav, Igor Durdanovic, Hanan
Samet, and Hans Peter Graf. 2017. Pruning fil-
ters for efficient convnets. In 5th International
Conference on Learning Representations.
Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao
Huang, Shoumeng Yan, and Changshui Zhang.
2017. Learning efficient convolutional net-
works through network slimming. In Proceed-
ings of the IEEE International Conference on
Computer Vision, pages 2736–2744.
Christos Louizos, Max Welling, and Diederik P.
Kingma. 2018. Learning sparse neural networks
through L0 regularization. In 6th International
Conference on Learning Representations.
Jian-Hao Luo, Jianxin Wu, and Weiyao Lin.
2017. Thinet: A filter level pruning method for
deep neural network compression. In Proceed-
ings of the IEEE International Conference on
Computer Vision (ICCV).
Chris J. Maddison, Andriy Mnih, and Yee Whye
Teh. 2017. The concrete distribution: A con-
tinuous relaxation of discrete random variables.
In 5th International Conference on Learning
Representations.
Chris J. Maddison, Daniel Tarlow, and Tom
Minka. 2014. A∗ sampling. In Advances in Neu-
ral Information Processing Systems, volume 27.
J. S. McCarley, Rishav Chakravarti, and Avirup
Sil. 2021. Structured pruning of a BERT-based
question answering model. CoRR,
abs/1910.06360v3.
Paul Michel, Omer Levy, and Graham Neubig.
2019. Are sixteen heads really better than one?
In Advances in Neural Information Processing
Systems, volume 32.
Dmitry Molchanov, Arsenii Ashukha, and Dmitry
Vetrov. 2017a. Variational dropout sparsifies
deep neural networks. In Proceedings of the
34th International Conference on Machine
Learning, volume 70 of Proceedings of Machine
Learning Research, pages 2498–2507. PMLR.
Pavlo Molchanov, Stephen Tyree, Tero Karras,
Timo Aila, and Jan Kautz. 2017b. Pruning
convolutional neural networks for resource
efficient inference. In 5th International
Conference on Learning Representations.
Nathan Ng, Kyra Yee, Alexei Baevski, Myle Ott,
Michael Auli, and Sergey Edunov. 2019. Facebook
FAIR’s WMT19 news translation task submission.
In Proceedings of the Fourth Conference on
Machine Translation (Volume 2: Shared Task
Papers, Day 1), pages 314–319, Florence, Italy.
Association for Computational Linguistics.
Myle Ott, Sergey Edunov, Alexei Baevski, Angela
Fan, Sam Gross, Nathan Ng, David Grangier,
and Michael Auli. 2019. fairseq: A fast,
extensible toolkit for sequence modeling. In
Proceedings of the 2019 Conference of the
North American Chapter of the Association for
Computational Linguistics (Demonstrations),
pages 48–53, Minneapolis, Minnesota.
Association for Computational Linguistics.
Tobias Plötz and Stefan Roth. 2018. Neural
nearest neighbors networks. In Advances in Neural
Information Processing Systems, volume 31.
Sai Prasanna, Anna Rogers, and Anna Rumshisky.
2020. When BERT plays the lottery, all tickets
are winning. In Proceedings of the 2020
Conference on Empirical Methods in Natural
Language Processing (EMNLP), pages 3208–3229,
Online. Association for Computational Linguistics.
Alec Radford, Jeffrey Wu, Rewon Child, David
Luan, Dario Amodei, and Ilya Sutskever. 2019.
Language models are unsupervised multitask
learners.
Hassan Sajjad, Fahim Dalvi, Nadir Durrani, and
Preslav Nakov. 2021. On the effect of dropping
layers of pre-trained transformer models. CoRR,
abs/2004.03844v2.
Victor Sanh, Thomas Wolf, and Alexander Rush.
2020. Movement pruning: Adaptive sparsity
by fine-tuning. In Advances in Neural In-
formation Processing Systems, volume 33,
pages 20378–20389.
Nitish Srivastava, Geoffrey Hinton, Alex
Krizhevsky, Ilya Sutskever, and Ruslan
Salakhutdinov. 2014. Dropout: A simple way
to prevent neural networks from overfitting.
Journal of Machine Learning Research,
15:1929–1958.
Lucas Theis, Iryna Korshunova, Alykhan Tejani,
and Ferenc Huszár. 2018. Faster gaze prediction
with dense networks and Fisher pruning. CoRR,
abs/1801.05787v2.
George Tucker, Andriy Mnih, Chris J Maddison,
John Lawson, and Jascha Sohl-Dickstein. 2017.
REBAR: Low-variance, unbiased gradient esti-
mates for discrete latent variable models. In
Advances in Neural Information Processing
Systems, volume 30.
Ashish Vaswani, Noam Shazeer, Niki Parmar,
Jakob Uszkoreit, Llion Jones, Aidan N. Gomez,
Łukasz Kaiser, and Illia Polosukhin. 2017. At-
tention is all you need. In Advances in Neural
Information Processing Systems, volume 30.
Tim Vieira. 2014. Gumbel-max trick and weighted
reservoir sampling.
Tim Vieira. 2021a. On the distribution function of
order statistics.
Tim Vieira. 2021b. On the distribution of the
smallest indices.
Jeffrey S. Vitter. 1985. Random sampling with a
reservoir. ACM Transactions on Mathematical
Software, 11(1):37–57.
Elena Voita, David Talbot, Fedor Moiseev, Rico
Sennrich, and Ivan Titov. 2019. Analyzing
multi-head self-attention: Specialized heads do
the heavy lifting, the rest can be pruned. In
Proceedings of the 57th Annual Meeting of
the Association for Computational Linguistics,
pages 5797–5808, Florence, Italy. Association
for Computational Linguistics.
Wei Wen, Chunpeng Wu, Yandan Wang, Yiran
Chen, and Hai Li. 2016. Learning structured
sparsity in deep neural networks. In Advances
in Neural Information Processing Systems,
volume 29.
Adina Williams, Nikita Nangia, and Samuel
Bowman. 2018. A broad-coverage challenge
corpus for sentence understanding through in-
ference. In Proceedings of the 2018 Confer-
ence of the North American Chapter of the
Association for Computational Linguistics: Hu-
man Language Technologies, Volume 1 (Long
Papers), pages 1112–1122. Association for
Computational Linguistics.
Thomas Wolf, Lysandre Debut, Victor Sanh,
Julien Chaumond, Clement Delangue, Anthony
Moi, Pierric Cistac, Tim Rault, Rémi Louf,
Morgan Funtowicz, Joe Davison, Sam Shleifer,
Patrick von Platen, Clara Ma, Yacine Jernite,
Julien Plu, Canwen Xu, Teven Le Scao, Sylvain
Gugger, Mariama Drame, Quentin Lhoest,
and Alexander M. Rush. 2020. Transformers:
State-of-the-art natural language processing. In
Proceedings of the 2020 Conference on Empir-
ical Methods in Natural Language Processing:
System Demonstrations, pages 38–45, Online.
Association for Computational Linguistics.
Sang Michael Xie and Stefano Ermon. 2019.
Reparameterizable subset sampling via contin-
uous relaxations. In International Joint Confer-
ence on Artificial Intelligence.
Zhilin Yang, Zihang Dai, Yiming Yang, Jaime
Carbonell, Ruslan Salakhutdinov, and Quoc V.
Le. 2019. XLNet: Generalized autoregressive
pretraining for language understanding. In
Advances in Neural Information Processing
Systems, volume 32.
John I. Yellott. 1977. The relationship between
Luce’s Choice Axiom, Thurstone’s Theory of
Comparative Judgment, and the double expo-
nential distribution. Journal of Mathematical
Psychology, 15(2):109–144.
Junru Zhou and Hai Zhao. 2019. Head-Driven
Phrase Structure Grammar parsing on Penn
Treebank. In Proceedings of the 57th Annual
Meeting of the Association for Computational
Linguistics, pages 2396–2408, Florence, Italy.
Association for Computational Linguistics.
Michael Zhu and Suyog Gupta. 2018. To prune, or
not to prune: Exploring the efficacy of pruning
for model compression. In 6th International
Conference on Learning Representations.
A Experimental Setup
We report the hyperparameters used for joint DSP
in our experiments in Table 2; they were obtained
by tuning on the validation set.
           τ_ini   τ_end   N_cooldown   lr for w_h
BERT       1000    1e-08   25000        0.5
Enc–Dec    0.1     1e-08   15000        0.2
Table 2: Hyperparameters used for joint DSP.
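Table 2 gives only the endpoints and the length of the temperature schedule. The sketch below assumes, without confirmation from this section, that τ decays geometrically from τ_ini to τ_end over N_cooldown steps and then stays at τ_end; the function name is hypothetical. Under this assumption log τ falls linearly with the step count.

```python
import math


def annealed_temperature(step, tau_ini, tau_end, n_cooldown):
    """Hypothetical schedule: geometric decay from tau_ini to tau_end over n_cooldown steps.

    log(tau) is interpolated linearly, so tau shrinks by a constant factor per
    step. After n_cooldown steps the temperature stays at tau_end.
    """
    if step >= n_cooldown:
        return tau_end
    frac = step / n_cooldown
    return math.exp((1.0 - frac) * math.log(tau_ini) + frac * math.log(tau_end))


# BERT settings from Table 2
for step in (0, 12500, 25000, 30000):
    print(step, annealed_temperature(step, tau_ini=1000.0, tau_end=1e-8, n_cooldown=25000))
```

Halfway through the cooldown this schedule sits at the geometric mean of the two endpoints, about 3.2e-3 for the BERT settings.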
B Analysis of Training Dynamics
We present two more examples where heads are
scarce (K = 8) or redundant (K = 108). In
Figure 8a, we observe the same two-phase training
behavior as for K = 12: the selected subset of
heads is not altered anymore after 16000 steps. In
Figure 8c, unlike the cases where there are very
few heads, the head masks are constantly updated
throughout training. Yet a large portion (91.7%)
of the heads remain unchanged after 17000 steps.
The two-phase behavior is still apparent in
comparison with training without annealing
(Figure 8d).
C Detailed Results
The detailed results for plotting Figure 2 are
presented in Table 3.
Figure 8: Training dynamics of joint DSP on BERT. The lower x-axis shows the number of training
steps, and the upper x-axis shows the corresponding temperature on a logarithmic scale. The left y-axis (orange)
shows test accuracy on the MNLI-mismatched validation set. The right y-axis (purple) shows the percentage
of heads selected at the current step that are eventually kept.
Unpruned Heads   Michel et al.   Pipelined DSP   Voita et al.   STE     Joint DSP
132              84.38           84.15           84.26          84.77   84.70
120              84.60           84.41           84.18          84.59   84.97
108              84.19           82.64           84.39          84.52   83.95
96               84.24           83.27           84.42          84.68   84.41
84               83.50           83.37           84.00          84.20   84.02
72               82.47           82.95           83.93          84.08   83.48
60               81.74           79.69           83.37          83.85   83.21
48               79.26           79.10           83.24          82.81   83.22
36               70.82           76.08           81.68          82.20   82.51
24               47.54           70.72           81.02          81.44   81.54
12               40.59           56.29           76.91          73.79   79.74
11               40.16           50.81           76.30          78.91   79.02
10               39.71           49.14           75.34          77.10   78.35
9                40.88           51.20           76.12          76.99   77.51
8                36.16           45.74           74.12          69.29   77.57
7                36.13           43.11           74.14          69.64   76.32
6                34.28           40.90           74.18          70.45   76.70
5                33.24           41.95           73.89          66.53   76.17
4                33.49           42.64           73.12          65.43   75.06
3                32.68           41.79           62.84          65.15   73.36
2                32.74           38.30           62.87          57.07   72.14
1                34.28           43.28           62.09          61.79   61.79
(a) Accuracy (%) on the MNLI-mismatched validation set as a function of the number of remaining heads in BERT.

Unpruned Heads   Michel et al.   Pipelined DSP   Voita et al.   STE     Joint DSP
68               32.87           34.19           34.10          34.69   34.52
64               29.08           34.29           34.19          34.55   34.51
60               11.18           32.21           34.14          34.56   34.83
56                6.91           32.52           34.19          34.19   34.46
52                4.41           33.02           34.23          33.92   34.79
48                2.64           31.58           34.20          34.02   34.82
44                2.30           28.70           34.08          33.88   34.68
40                1.70           24.35           34.06          33.85   34.13
36                1.20           25.84           33.82          33.22   34.58
32                0.61           23.94           33.70          32.88   34.10
28                0.19           16.63           33.78          32.01   33.89
24                0.13           20.40           33.44          33.71   33.72
20                0.07           14.11           33.25          31.27   33.54
16                0.07            7.55           32.62          31.25   32.32
12                0.05            3.80           32.33          30.71   32.74
8                 0.04            0.63           31.26          28.77   32.68
4                 0.04            0.16           29.09          25.45   30.33
3                 0.04            0.09           23.08          23.83   28.22
2                 0.04            0.05           20.89          22.35   24.18
1                 0.04            0.05           20.38          20.37   20.64
(b) BLEU score on the IWSLT test set as a function of the number of unpruned heads in Enc–Dec.
Table 3: A comparison of various pruning methods.