Transformer models are in the spotlight these days, underlying large language models (LLMs) like GPT-4 and IBM Granite. One powerful ability is at least partly responsible for their growing popularity: in-context learning, which enriches a model with examples without having to retrain it. The technique lets a model extract information from the prompt a user supplies and use it to answer the question at hand.
In-context learning makes predictions more accurate without adding anywhere near the computing cost or human labor that retraining a model would. Yet even with real-world evidence that in-context learning works, the details of why it works have been foggy, until now.
The key to in-context learning lies in how a transformer’s self-attention layer processes in-context examples. That’s according to new research from a team of scientists at IBM Research and Rensselaer Polytechnic Institute (RPI), working together through the Rensselaer-IBM AI Research Collaboration. The self-attention layer, which sets the architecture of a transformer apart from that of other models, prizes in-context examples when they’re similar to the model’s training data. This helps explain why more context isn’t always better, says the team behind the work. Rather, the added context needs to be relevant to the model.
For example, when asking a transformer-based LLM to translate text from English to French, relevant context would include pairs of English-to-French translations, says study co-author Hongkang Li, a Ph.D. student at RPI specializing in deep learning theory and machine learning.
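Here’s a minimal sketch of what such an in-context prompt could look like. The demonstration pairs and the plain-text format below are illustrative assumptions, not material from the study; the point is simply that the task-relevant examples ride along in the prompt, and no model weights change.

```python
# A minimal sketch of an in-context (few-shot) prompt for English-to-French
# translation. The demonstration pairs and format are illustrative; the key
# point is that the examples live in the prompt, not in updated model weights.
demonstrations = [
    ("Good morning.", "Bonjour."),
    ("Where is the train station?", "Où est la gare ?"),
    ("Thank you very much.", "Merci beaucoup."),
]

query = "How much does this cost?"

prompt_parts = [f"English: {en}\nFrench: {fr}" for en, fr in demonstrations]
prompt_parts.append(f"English: {query}\nFrench:")

prompt = "\n\n".join(prompt_parts)
print(prompt)  # This string would be sent to an LLM as-is; no retraining occurs.
```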
In a series of classification task experiments with transformers, he and his IBM collaborators fed the models different kinds of in-context examples and quantified how much each one improved a given model’s performance on a set of examples it hadn’t seen before.
A longer context window, meaning more examples, doesn’t necessarily make for a better tutorial, according to the new study. The quality of the predictions depends on how good that tutorial is. That’s what people have intuitively assumed, but this is the first time it has been shown in detail. The team’s theoretical result also characterizes how the number of relevant in-context examples improves the model’s predictions.
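To see why relevance can matter more than sheer quantity, here is a toy, self-contained sketch, not the team’s model or data: scaled dot-product attention concentrates its weight on the few context embeddings that resemble the query, even when many unrelated examples are packed into the context.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 16

def unit(v):
    return v / np.linalg.norm(v)

def attention_weights(query, keys, scale=4.0):
    """Softmax over dot-product similarity scores. The fixed `scale` is a
    stand-in for the sharpening that learned query/key projections provide."""
    scores = scale * (keys @ query)
    scores -= scores.max()            # numerical stability
    w = np.exp(scores)
    return w / w.sum()

# A toy "task direction" standing in for features resembling the training data.
task = unit(rng.normal(size=dim))
query = unit(task + 0.1 * rng.normal(size=dim))

# Two relevant in-context examples plus eight irrelevant ones.
relevant = np.stack([unit(task + 0.1 * rng.normal(size=dim)) for _ in range(2)])
irrelevant = np.stack([unit(rng.normal(size=dim)) for _ in range(8)])
keys = np.vstack([relevant, irrelevant])

w = attention_weights(query, keys)
print("attention on 2 relevant examples :", round(float(w[:2].sum()), 3))
print("attention on 8 irrelevant examples:", round(float(w[2:].sum()), 3))
```

In this toy run, the two relevant examples collect the bulk of the attention despite being outnumbered four to one.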
“It is quite surprising that a transformer structure is able to predict correct answers for unseen context inputs,” says Songtao Lu, the principal investigator of this IBM-RPI project, from the IBM Math and Theoretical Computer Science group. “Going beyond the classic supervised learning theory, our work rigorously and mathematically reveals, for the first time, self-attention’s out-of-domain generalization ability.”
Their findings, along with the details of how they performed these experiments, are being presented at the 2024 International Conference on Machine Learning (ICML), happening July 21 to 27 in Vienna.
The team behind the work wants to pin down theoretical explanations for what’s being observed in LLMs and generative AI. Part of building trustworthy AI, as they see it, is looking into the underlying working mechanisms of these complicated systems, understanding them bit by bit and component by component, to glean how they work or when they will succeed or fail. In this sense, their work has promising implications for how we can train a better model, prevent misuse, and avoid bias.
“We’re adding transparency to this dark magic,” says senior researcher Pin-Yu Chen of the Trusted AI group at IBM Research. “People use it a lot but don’t understand how it works.” In-context learning has gotten popular because it’s done without updating the model weights, a process that could take hours to months depending on the model’s size.
The team’s long-term goal is to extend this analysis to other types of AI tasks, including text and video generation, as well as the kinds of inference performed by generative models like GPT-4 and Sora.
In another time- and money-saving result, their method also points to ways large models can be pruned down to speed up the inference process. In most cases, some neurons, or parameters, end up being redundant for the final inference. Based on the theory identified in their study, a technique called magnitude-based pruning could be used to keep only the high-impact neurons.
“Magnitude-based pruning means that model parameters with smaller absolute values also have smaller effects on the final prediction,” RPI’s Li says. The pruned model’s accuracy is almost the same as the unpruned model’s, and in theory nearly 20% of the parameters can be removed without really affecting the final results. In practice, with more careful pruning, they say they can shave off even more.
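As a rough illustration of the general idea, not the paper’s procedure, the sketch below applies magnitude-based pruning to a random stand-in weight matrix: zeroing out the 20% of parameters with the smallest absolute values changes one layer’s output only slightly.

```python
import numpy as np

def magnitude_prune(weights, prune_fraction=0.2):
    """Zero out the given fraction of weights with the smallest absolute values.
    Generic magnitude-based pruning; the 20% default echoes the rate mentioned
    above, but the matrix here is random, not a trained transformer layer."""
    flat = np.abs(weights).ravel()
    k = int(prune_fraction * flat.size)
    if k == 0:
        return weights.copy()
    threshold = np.partition(flat, k - 1)[k - 1]   # k-th smallest magnitude
    pruned = weights.copy()
    pruned[np.abs(pruned) <= threshold] = 0.0
    return pruned

rng = np.random.default_rng(0)
W = rng.normal(size=(64, 64))       # stand-in for one layer's weight matrix
W_pruned = magnitude_prune(W, prune_fraction=0.2)
x = rng.normal(size=64)             # stand-in for a layer input

print("fraction of weights zeroed:", round(float((W_pruned == 0).mean()), 3))
print("relative change in the layer's output:",
      round(float(np.linalg.norm(W @ x - W_pruned @ x) / np.linalg.norm(W @ x)), 3))
```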
People have done this with convolutional neural networks (CNNs) and with LLMs, but there was no theoretical justification, Li says. “Our work is the first to provide a pruning rate that is tolerable for in-context learning.”
After the AI field switched from CNNs to transformers (the architecture that underpins generative models), the intuitive understanding was that transformers just work better. But there wasn’t a formal explanation for why. The IBM-RPI team released a paper last year proving that transformers learn and generalize better than CNNs, giving theoretical justification to the shift in the field. Now they have a better idea why that is, and the new study takes that transparency a step further.
“We believe theories and experiments should work together, so our theory is motivated by empirical observations,” Li says. “And we hope that later our theory will guide empirical experiments.”