ChatGPT and chatbot-powered applications have captured significant attention in the Natural Language Processing (NLP) domain. The community is constantly seeking strong, reliable and open-source models for their applications and use cases.
The rise of these powerful models stems from the democratization and widespread adoption of transformer-based models, first introduced by Vaswani et al. in 2017. These models significantly outperformed the previous SoTA NLP models based on Recurrent Neural Networks (RNNs), which largely fell out of favor after that paper.
Through this blogpost, we will introduce the integration of a new architecture, RWKV, that combines the advantages of both RNNs and transformers, and that has been recently integrated into the Hugging Face transformers library.
Overview of the RWKV project
The RWKV project was kicked off and is being led by Bo Peng, who is actively contributing to and maintaining the project. The community, organized in the official Discord channel, is constantly enhancing the project’s artifacts on various topics such as performance (RWKV.cpp, quantization, etc.), scalability (dataset processing & scraping) and research (chat fine-tuning, multi-modal fine-tuning, etc.). The GPUs for training RWKV models are donated by Stability AI.
The RNN architecture is one of the first widely used neural network architectures for processing sequences of data, unlike classic architectures that take a fixed-size input. It takes as input the current “token” (i.e. the current data point of the data stream) and the previous “state”, and computes the predicted next token and the predicted next state. The new state is then used to compute the prediction of the next token, and so on.
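As a rough sketch of such a recurrent step (the weight names below are hypothetical and only serve to illustrate the idea, not any particular RNN variant):

import torch

def rnn_step(x_t, h_prev, W_xh, W_hh, W_hy):
    # The same weights are reused at every position of the sequence.
    h_t = torch.tanh(W_xh @ x_t + W_hh @ h_prev)  # new state from the current input and the previous state
    y_t = W_hy @ h_t                              # logits used to predict the next token
    return y_t, h_t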
An RNN can also be used in different “modes”, enabling RNNs to be applied to different scenarios, as described in Andrej Karpathy’s blogpost, such as one-to-one (image classification), one-to-many (image captioning), many-to-one (sequence classification), many-to-many (sequence generation), etc.
Because RNNs use the same weights to compute predictions at every step, they struggle to memorize information for long-range sequences due to the vanishing gradient issue. Efforts have been made to address this limitation by introducing new architectures such as LSTMs or GRUs. However, the transformer architecture proved to be the most effective thus far in resolving this issue.
In the transformer architecture, the input tokens are processed simultaneously in the self-attention module. The tokens are first linearly projected into different spaces using the query, key and value weights. The resulting matrices are directly used to compute the attention scores (through softmax, as shown below), then multiplied by the value hidden states to obtain the final hidden states. This design enables the architecture to effectively mitigate the long-range sequence issue, and also perform faster inference and training compared to RNN models.
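For reference, this is the standard scaled dot-product attention, where Q, K and V are the query, key and value projections and d_k is the key dimension:

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right) V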
During training, the Transformer architecture has several advantages over traditional RNNs and CNNs. One of the most significant is its ability to learn contextual representations. Unlike RNNs, which process the input one token at a time, and CNNs, which only see a local window at each layer, the Transformer processes the input sequence as a whole. This allows it to capture long-range dependencies between tokens in the sequence, which is particularly useful for tasks such as language translation and question answering.
During inference, RNNs have some advantages in speed and memory efficiency: simplicity, since only matrix-vector operations are needed, and memory efficiency, since the memory requirements do not grow during inference. Furthermore, the per-token computation cost stays constant regardless of the context window length, because each step only operates on the current token and the state.
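To make this concrete, here is a minimal sketch of recurrent greedy decoding with random, hypothetical weights: each step only performs matrix-vector products on the current token and a fixed-size state, so per-token compute and memory stay constant however long the context grows.

import torch

hidden_size, vocab_size = 16, 100
# Hypothetical weights; a real model would load trained parameters instead.
W_xh = torch.randn(hidden_size, hidden_size)
W_hh = torch.randn(hidden_size, hidden_size)
W_hy = torch.randn(vocab_size, hidden_size)
embeddings = torch.randn(vocab_size, hidden_size)

state = torch.zeros(hidden_size)
token = 0
for _ in range(20):
    x = embeddings[token]
    state = torch.tanh(W_xh @ x + W_hh @ state)  # matrix-vector operations only
    logits = W_hy @ state
    token = int(torch.argmax(logits))            # greedy choice of the next token
# Only `state` is carried across steps, so memory does not grow with the context length.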
The RWKV architecture
RWKV is inspired by Apple’s Attention Free Transformer. The architecture has been carefully simplified and optimized such that it can be transformed into an RNN. In addition, a number of tricks have been added, such as TokenShift & SmallInitEmb (the full list is in the README of the official GitHub repository), to boost its performance to match GPT. Without these, the model wouldn't be as performant.
For training, there is an infrastructure to scale training up to 14B parameters as of now, and some issues, such as numerical instability, have been iteratively fixed in RWKV-4 (the latest version as of today).
RWKV as a combination of RNNs and transformers
How can we combine the best of transformers and RNNs? The main drawback of transformer-based models is that it becomes challenging to run a model with a context window larger than a certain value, as the attention scores are computed simultaneously for the entire sequence, with a cost that grows quadratically with the sequence length.
RNNs natively support very long context lengths - only limited by the context length seen in training, but this can be extended to millions of tokens with careful coding. Currently, there are RWKV models trained on a context length of 8192 (ctx8192) and they are as fast as ctx1024 models and require the same amount of RAM.
The major drawbacks of traditional RNN models and how RWKV is different:
Traditional RNN models are unable to utilize very long contexts (LSTMs can only manage ~100 tokens when used as a LM), whereas RWKV can utilize thousands of tokens and beyond.
Traditional RNN models cannot be parallelized when training. RWKV is similar to a “linearized GPT” and it trains faster than GPT.
By combining both advantages into a single architecture, the hope is that RWKV can grow to become more than the sum of its parts.
RWKV attention formulation
The model architecture is very similar to classic transformer-based models (i.e. an embedding layer, multiple identical layers, layer normalization, and a Causal Language Modeling head to predict the next token). The only difference is the attention layer, which is completely different from that of traditional transformer-based models.
To gain a more comprehensive understanding of the attention layer, we recommend delving into the detailed explanation provided in a blog post by Johan Sokrates Wind.
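For readers who only want the gist: the core of RWKV’s “attention” is a WKV computation that replaces quadratic softmax attention with an exponentially decaying weighted average of past values, which can be evaluated as a recurrence. Following the formulation in the RWKV paper and the linked post, it is usually written per channel (with w a learned decay, u a learned bonus for the current token, and k_i, v_i the key and value of token i) roughly as:

wkv_t = \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} v_i \;+\; e^{u + k_t} v_t}{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} \;+\; e^{u + k_t}}

The linked blog post explains how this is computed in a numerically stable, recurrent form.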
Existing checkpoints
Pure language models: RWKV-4 models
Most adopted RWKV models range from ~170M parameters to 14B parameters. According to the RWKV overview blog post, these models have been trained on the Pile dataset and evaluated against other SoTA models on different benchmarks, and they perform quite well, with results very comparable to those models.
Instruction Fine-tuned/Chat Version: RWKV-4 Raven
Bo has also trained a “chat” version of the RWKV architecture, the RWKV-4 Raven model. It is an RWKV-4 Pile model (an RWKV model pretrained on the Pile dataset) fine-tuned on ALPACA, CodeAlpaca, Guanaco, GPT4All, ShareGPT and more. The model is available in multiple versions, trained on different languages (English only, English + Chinese + Japanese, English + Japanese, etc.) and in different sizes (1.5B parameters, 7B parameters, 14B parameters).
All the HF converted models are available on the Hugging Face Hub, in the RWKV organization.
🤗 Transformers integration
The architecture has been added to the transformers library thanks to this Pull Request. As of the time of writing, you can use it by installing transformers from source, or by using the main branch of the library. The architecture is tightly integrated with the library, and you can use it as you would any other architecture.
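For example, installing from source typically amounts to the following (if RWKV support has already shipped in a stable release by the time you read this, a regular pip install of transformers is enough):

pip install git+https://github.com/huggingface/transformers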
Let us walk through some examples below.
Text Generation Example
To generate text given an input prompt, you can use pipeline:
from transformers import pipeline

model_id = "RWKV/rwkv-4-169m-pile"
prompt = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."

pipe = pipeline("text-generation", model=model_id)
print(pipe(prompt, max_new_tokens=20))
>>> [{'generated_text': '\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.\n\nThe researchers found that the dragons were able to communicate with each other, and that they were'}]
Or you can load the model and tokenizer directly, starting from the snippet below:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile")
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")

prompt = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
inputs = tokenizer(prompt, return_tensors="pt")

output = model.generate(inputs["input_ids"], max_new_tokens=20)
print(tokenizer.decode(output[0].tolist()))
>>> In a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.\n\nThe researchers found that the dragons were able to communicate with each other, and that they were
Use the Raven models (chat models)
You can prompt the chat model in the Alpaca style; here is an example below:
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "RWKV/rwkv-raven-1b5"
model = AutoModelForCausalLM.from_pretrained(model_id).to(0)
tokenizer = AutoTokenizer.from_pretrained(model_id)

question = "Tell me about ravens"
prompt = f"### Instruction: {question}\n### Response:"
inputs = tokenizer(prompt, return_tensors="pt").to(0)

output = model.generate(inputs["input_ids"], max_new_tokens=100)
print(tokenizer.decode(output[0].tolist(), skip_special_tokens=True))
>>> ### Instruction: Tell me about ravens
### Response: RAVENS are a type of bird that is native to the Middle East and North Africa. They are known for their intelligence, adaptability, and their ability to live in a variety of environments. RAVENS are known for their intelligence, adaptability, and their ability to live in a variety of environments. They are known for their intelligence, adaptability, and their ability to live in a variety of environments.
Any user could easily convert the original RWKV weights to the HF format by simply running the conversion script provided in the transformers library. First, push the "raw" weights to the Hugging Face Hub (let's denote that repo as RAW_HUB_REPO, and the raw file RAW_FILE), then run the conversion script:
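The invocation should look roughly like the following; double-check the exact script location and flag names against the version of transformers you have installed, and note that OUTPUT_DIR is simply a local directory of your choice:

python convert_rwkv_checkpoint_to_hf.py --repo_id RAW_HUB_REPO --checkpoint_file RAW_FILE --output_dir OUTPUT_DIR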
If you want to push the converted model to the Hub (let's say, under dummy_user/converted-rwkv), don't forget to log in with huggingface-cli login before pushing the model, then run:
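Again, roughly (flag names may differ slightly across transformers versions):

python convert_rwkv_checkpoint_to_hf.py --repo_id RAW_HUB_REPO --checkpoint_file RAW_FILE --output_dir OUTPUT_DIR --push_to_hub --model_name dummy_user/converted-rwkv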
Due to only needing matrix-vector operations, RWKV is an ideal candidate for non-standard and experimental computing hardware, such as photonic processors/accelerators.
Therefore, the architecture can also naturally benefit from classic acceleration and compression techniques (such as ONNX, 4-bit/8-bit quantization, etc.), and we hope this will be democratized for developers and practitioners together with the transformers integration of the architecture.
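As an illustration of the quantization point, loading an RWKV checkpoint in 8-bit should work through the usual transformers + bitsandbytes path. The snippet below is only a sketch: it assumes bitsandbytes and accelerate are installed and a GPU is available.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "RWKV/rwkv-4-169m-pile"
# load_in_8bit relies on bitsandbytes; device_map="auto" dispatches the weights automatically
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("In a shocking finding, ", return_tensors="pt").to(model.device)
output = model.generate(inputs["input_ids"], max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))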
The Hugging Face team would like to thank Bo and the RWKV community for their time and for answering our questions about the architecture. We would also like to thank them for their help and support, and we look forward to seeing more adoption of RWKV models in the HF ecosystem.
We also would like to acknowledge the work of Johan Wind for his blogpost on RWKV, which helped us a lot to understand the architecture and its potential.
And finally, we would like to highlight and acknowledge the work of ArEnSc for starting over the initial transformers PR.
Also big kudos to Merve Noyan, Maria Khalusova and Pedro Cuenca for kindly reviewing this blogpost to make it much better!