4.3 Rule Extraction from LSTM Networks
In this section, we consider long short-term memory networks (LSTMs), which were discussed in Chapter 3, and describe an approach for tracking the importance of a given input to the LSTM for a given output. By identifying consistently important patterns of words, we are able to distill state-of-the-art LSTMs on sentiment analysis and question answering into a set of representative phrases. This representation is then quantitatively validated by using the extracted phrases to construct a simple rule-based classifier that approximates the output of the LSTM.
Word importance scores in LSTMs: Here, we present a decomposition of the output of an LSTM into a product of factors, where each term in the product can be interpreted as the contribution of a particular word. Thus, we can assign importance scores to words according to their contribution to the LSTM's prediction. We have introduced the basics of LSTM networks in Chapter 3. Given a sequence of word embeddings $x_1, \ldots, x_T \in \mathbb{R}^d$, an LSTM processes one word at a time, keeping track of cell and state vectors $(c_1, h_1), \ldots, (c_T, h_T)$, which contain the information in the sentence up to word $t$. Here $h_t$ and $c_t$ are computed as a function of $x_t$, $c_{t-1}$, and $h_{t-1}$ using the updates given by Eq. (3.72) of Chapter 3, which we repeat here with slightly different notation:

(4.25)
$$
\begin{aligned}
f_t &= \sigma(W_f x_t + V_f h_{t-1} + b_f)\\
i_t &= \sigma(W_i x_t + V_i h_{t-1} + b_i)\\
o_t &= \sigma(W_o x_t + V_o h_{t-1} + b_o)\\
\tilde{c}_t &= \tanh(W_c x_t + V_c h_{t-1} + b_c)\\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t\\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$
As initial values, we define $c_0 = h_0 = 0$. After processing the full sequence, a probability distribution over $C$ classes is specified by $p$, with

(4.26)
$$
p_i = \operatorname{softmax}(W h_T)_i = \frac{\exp(W_i h_T)}{\sum_{j=1}^{C} \exp(W_j h_T)}
$$

where $W_i$ is the $i$-th row of the matrix $W$.
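To make the decompositions that follow concrete, here is a minimal NumPy sketch of the forward pass in Eqs. (4.25) and (4.26). It is illustrative only: the function name lstm_forward, the params dictionary, and its key names (W_f, V_f, b_f, …, W) are assumptions of this sketch, not part of the original method; any trained LSTM whose gates and cell states can be recorded per step would serve the same purpose.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_forward(X, params):
    """Run the updates of Eq. (4.25) over a sequence of word embeddings X (T x d).

    `params` is assumed (hypothetically) to hold trained weights
    W_f, V_f, b_f, W_i, V_i, b_i, W_o, V_o, b_o, W_c, V_c, b_c for the gates
    and W (C x hidden) for the output layer.  The per-step gates and cell
    states are recorded because the decompositions below need them.
    """
    hidden = params["b_f"].shape[0]
    c, h = np.zeros(hidden), np.zeros(hidden)          # c_0 = h_0 = 0
    traj = {"f": [], "i": [], "o": [], "c_tilde": [], "c": []}
    for x_t in X:
        f = sigmoid(params["W_f"] @ x_t + params["V_f"] @ h + params["b_f"])
        i = sigmoid(params["W_i"] @ x_t + params["V_i"] @ h + params["b_i"])
        o = sigmoid(params["W_o"] @ x_t + params["V_o"] @ h + params["b_o"])
        c_tilde = np.tanh(params["W_c"] @ x_t + params["V_c"] @ h + params["b_c"])
        c = f * c + i * c_tilde                        # c_t = f_t ⊙ c_{t-1} + i_t ⊙ ~c_t
        h = o * np.tanh(c)                             # h_t = o_t ⊙ tanh(c_t)
        for name, val in zip(("f", "i", "o", "c_tilde", "c"), (f, i, o, c_tilde, c)):
            traj[name].append(val)
    traj = {k: np.array(v) for k, v in traj.items()}
    logits = params["W"] @ h                           # W h_T
    p = np.exp(logits - logits.max())
    traj["p"] = p / p.sum()                            # Eq. (4.26)
    return traj
```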
Decomposing the output of an LSTM: We now decompose the numerator of $p_i$ in Eq. (4.26) into a product of factors and show that we can interpret those factors as the contribution of individual words to the predicted probability of class $i$. Define

(4.27)
$$
\beta_{i,j} = \exp\!\Big(W_i\big(o_T \odot \big(\tanh(c_j) - \tanh(c_{j-1})\big)\big)\Big)
$$

so that

$$
\exp(W_i h_T) = \exp\!\Big(\sum_{j=1}^{T} W_i\big(o_T \odot \big(\tanh(c_j) - \tanh(c_{j-1})\big)\big)\Big) = \prod_{j=1}^{T} \beta_{i,j}.
$$

Since $\tanh(c_j) - \tanh(c_{j-1})$ can be viewed as the update resulting from word $j$, $\beta_{i,j}$ can be interpreted as the multiplicative contribution of word $j$ to $p_i$.
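As a sketch of how $\beta_{i,j}$ could be computed from the recorded trajectory, assuming the hypothetical lstm_forward above: multiplying beta[i, :] along the sequence recovers $\exp(W_i h_T)$ up to floating-point error, which serves as a quick check of Eq. (4.27).

```python
import numpy as np

def multiplicative_scores(traj, W):
    """beta[i, j] = exp(W_i (o_T ⊙ (tanh(c_j) - tanh(c_{j-1})))), Eq. (4.27)."""
    cells, o_T = traj["c"], traj["o"][-1]
    C, T = W.shape[0], len(cells)
    beta = np.empty((C, T))
    prev_tanh = np.zeros_like(cells[0])                # tanh(c_0) = 0
    for j in range(T):
        beta[:, j] = np.exp(W @ (o_T * (np.tanh(cells[j]) - prev_tanh)))
        prev_tanh = np.tanh(cells[j])
    return beta                                        # beta[i].prod() ≈ exp(W_i h_T)
```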
An additive decomposition of the LSTM cell: We will show below that $\beta_{i,j}$ captures some notion of the importance of a word to the LSTM's output. However, these terms fail to account for how the information contributed by word $j$ is affected by the LSTM's forget gates between words $j$ and $T$. Consequently, it was empirically found [93] that the importance scores from this approach often yield a considerable number of false positives. A more nuanced approach is obtained by considering the additive decomposition of $c_T$ in Eq. (4.28), where each term $e_j$ can be interpreted as the contribution to the cell state $c_T$ by word $j$. By iterating the cell update $c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$, we obtain that

(4.28)
$$
c_T = \sum_{j=1}^{T} \Big(\prod_{k=j+1}^{T} f_k\Big) \odot i_j \odot \tilde{c}_j = \sum_{j=1}^{T} e_j.
$$
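A small sketch of Eq. (4.28), again assuming the trajectory dictionary produced by the hypothetical lstm_forward above; the assertion simply verifies that the contributions $e_j$ sum to the final cell state $c_T$.

```python
import numpy as np

def additive_cell_decomposition(traj):
    """Terms e_j of Eq. (4.28): e_j = (Π_{k=j+1}^T f_k) ⊙ i_j ⊙ ~c_j; they sum to c_T."""
    forgets, in_gates, c_tildes, cells = traj["f"], traj["i"], traj["c_tilde"], traj["c"]
    T = len(cells)
    e = []
    suffix = np.ones_like(forgets[0])              # running product Π_{k=j+1}^{T} f_k
    for j in range(T - 1, -1, -1):
        e.append(suffix * in_gates[j] * c_tildes[j])
        suffix = suffix * forgets[j]
    e = np.array(e[::-1])                          # e[j] is the contribution of word j+1
    assert np.allclose(e.sum(axis=0), cells[-1])   # sanity check: Σ_j e_j = c_T
    return e
```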
This suggests a natural definition of an alternative score to $\beta_{i,j}$, corresponding to augmenting the $c_j$ terms with the products of the forget gates to reflect the upstream changes made to $c_j$ after initially processing word $j$:

(4.29)
$$
\gamma_{i,j} = \exp\!\Bigg(W_i\bigg(o_T \odot \Big(\tanh\Big(\Big(\textstyle\prod_{k=j+1}^{T} f_k\Big) \odot c_j\Big) - \tanh\Big(\Big(\textstyle\prod_{k=j}^{T} f_k\Big) \odot c_{j-1}\Big)\Big)\bigg)\Bigg).
$$
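A corresponding sketch for the forget-gate-adjusted scores of Eq. (4.29), under the same assumptions as the earlier snippets; the suffix products of forget gates are accumulated right to left, so that each $c_j$ is discounted by the gates applied after word $j$.

```python
import numpy as np

def gated_scores(traj, W):
    """gamma[i, j] of Eq. (4.29): beta with each c_j rescaled by f_{j+1} ⊙ ... ⊙ f_T."""
    cells, forgets, o_T = traj["c"], traj["f"], traj["o"][-1]
    C, T = W.shape[0], len(cells)
    # suffix[j] = Π_{k=j+2}^{T} f_k in 1-indexed terms, with suffix[T-1] = 1
    suffix = np.ones_like(forgets)
    for j in range(T - 2, -1, -1):
        suffix[j] = forgets[j + 1] * suffix[j + 1]
    gamma = np.empty((C, T))
    prev = np.zeros_like(cells[0])                 # (Π_{k=1}^T f_k) ⊙ c_0 = 0
    for j in range(T):
        cur = suffix[j] * cells[j]                 # (Π_{k=j+1}^T f_k) ⊙ c_j, 1-indexed
        gamma[:, j] = np.exp(W @ (o_T * (np.tanh(cur) - np.tanh(prev))))
        prev = cur
    return gamma
```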
We now introduce a technique for using our variable importance scores to extract phrases from a trained LSTM. To do so, we search for phrases that consistently provide a large contribution to the prediction of a particular class relative to other classes. The utility of these patterns is validated by using them as input for a rules‐based classifier. For simplicity, we focus on the binary classification case.
Phrase extraction: A phrase can be reasonably described as predictive if, whenever it occurs, it causes a document both to be labeled as a particular class and not to be labeled as any other. As the importance scores introduced above correspond to the contributions of particular words to class predictions, they can be used to score potential patterns by looking at a pattern's average contribution to the prediction of a given class relative to other classes. In other words, given a collection of $D$ documents, for a given phrase $w_1, \ldots, w_k$ we can compute scores $S_1$, $S_2$ for classes 1 and 2, as well as a combined score $S$ and class $C$, as

(4.30)
$$
S_1 = \frac{\overline{\prod_{j} \beta_{1,j,d}}}{\overline{\prod_{j} \beta_{2,j,d}}}, \qquad
S_2 = \frac{\overline{\prod_{j} \beta_{2,j,d}}}{\overline{\prod_{j} \beta_{1,j,d}}}, \qquad
S = \max(S_1, S_2), \qquad
C = \operatorname*{arg\,max}_{i \in \{1,2\}} S_i,
$$

where $\beta_{i,j,d}$ denotes $\beta_{i,j}$ applied to document $d$, the products run over the positions occupied by the phrase in an occurrence, and $\overline{\,\cdot\,}$ stands for the average over all occurrences of the phrase in the collection.

The numerator of $S_1$ denotes the average contribution of the phrase to the prediction of class 1 across all occurrences of the phrase. The denominator denotes the same statistic, but for class 2. Thus, if $S_1$ is high, then $w_1, \ldots, w_k$ is a strong signal for class 1, and likewise for $S_2$. It was proposed in [93] to use $S$ as a score function in order to search for high-scoring representative phrases that provide insight into the trained LSTM, and $C$ to denote the class corresponding to a phrase.
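A possible implementation of the scoring in Eq. (4.30); the input format (one 2 × k array of β or γ scores per occurrence of the phrase) and the function name phrase_scores are assumptions made for illustration.

```python
import numpy as np

def phrase_scores(occurrence_betas):
    """Score a phrase from its occurrences, following Eq. (4.30).

    `occurrence_betas` holds one (2 x k) array per occurrence of the phrase,
    giving the beta (or gamma) scores of its k words for classes 1 and 2.
    """
    # average contribution of the phrase to each class over all occurrences
    avg = np.mean([occ.prod(axis=1) for occ in occurrence_betas], axis=0)
    s1 = avg[0] / avg[1]
    s2 = avg[1] / avg[0]
    S = max(s1, s2)                    # combined score
    C = 1 if s1 >= s2 else 2           # class the phrase signals
    return S, C
```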
In practice, the number of phrases is too large to feasibly compute the score for all of them. Thus, we approximate a brute-force search through a two-step procedure. First, we construct a list of candidate phrases by searching for strings of consecutive words $j$ with importance scores $\beta_{i,j} > c$ for any $i$ and some threshold $c$. Then, we score and rank the set of candidate phrases, which is much smaller than the set of all phrases.
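The first step of this two-step search might look as follows; the function name candidate_phrases, the token/score inputs, and the single threshold are illustrative assumptions.

```python
def candidate_phrases(doc_tokens, doc_betas, threshold):
    """Step 1: collect maximal runs of consecutive words whose importance
    score exceeds `threshold` for either class (doc_betas has shape (2, T))."""
    candidates, run = [], []
    for t, token in enumerate(doc_tokens):
        if doc_betas[:, t].max() > threshold:      # beta_{i,t} > c for some class i
            run.append(token)
        else:
            if run:
                candidates.append(tuple(run))
            run = []
    if run:
        candidates.append(tuple(run))
    return candidates
```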
Rules-based classifier: The patterns extracted above can be used to construct a simple rules-based classifier that approximates the output of the original LSTM. Given a document and a list of patterns sorted by descending score given by S, the classifier sequentially searches for each pattern within the document using simple string matching. Once it finds a pattern, the classifier returns the associated class given by C, ignoring the lower-ranked patterns. The resulting classifier is interpretable, and despite its simplicity, retains much of the accuracy of the LSTM used to build it.
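A sketch of such a classifier, assuming the patterns are stored as (phrase string, class) pairs already sorted by descending S; the fallback default_class for documents containing no pattern is an assumption not specified in the text.

```python
def rules_classifier(document, ranked_patterns, default_class=1):
    """Approximate the LSTM with ranked string-matching rules.

    `ranked_patterns` is a list of (phrase, C) pairs sorted by descending
    score S; the first phrase found in `document` decides the class.
    """
    for phrase, cls in ranked_patterns:
        if phrase in document:          # simple substring match
            return cls
    return default_class                # fallback when no pattern matches
```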