    Image Captioning Using Attention

    This example shows how to train a deep learning model for image captioning using attention.

    Most pretrained deep learning networks are configured for single-label classification. For example, given an image of a typical office desk, the network might predict the single class "keyboard" or "mouse". In contrast, an image captioning model combines convolutional and recurrent operations to produce a textual description of what is in the image, rather than a single label.

    The model trained in this example uses an encoder-decoder architecture. The encoder is a pretrained Inception-v3 network used as a feature extractor. The decoder is a recurrent neural network (RNN) that takes the extracted features as input and generates a caption. The decoder incorporates an attention mechanism that allows it to focus on parts of the encoded input while generating the caption.

    The encoder model is a pretrained Inception-v3 model that extracts features from the "mixed10" layer, followed by fully connected and ReLU operations.

    The decoder model consists of a word embedding, an attention mechanism, a gated recurrent unit (GRU), and two fully connected operations.

    Load Pretrained Network

    Load a pretrained Inception-v3 network. This step requires the Deep Learning Toolbox™ Model for Inception-v3 Network support package. If you do not have the required support package installed, then the software provides a download link.

    net = imagePretrainedNetwork("inceptionv3");
    inputSizeNet = net.Layers(1).InputSize;

    Remove the last three layers, leaving the "mixed10" layer as the last layer.

    net = removeLayers(net, ["avg_pool" "predictions" "predictions_softmax"]);

    View the input layer of the network. The Inception-v3 network uses symmetric-rescale normalization with a minimum value of 0 and a maximum value of 255.

    net.Layers(1)
    ans = 
      ImageInputLayer with properties:
                          Name: 'input_1'
                     InputSize: [299 299 3]
            SplitComplexInputs: 0
       Hyperparameters
              DataAugmentation: 'none'
                 Normalization: 'rescale-symmetric'
        NormalizationDimension: 'auto'
                           Max: 255
                           Min: 0
    

    Custom training does not support this normalization, so you must disable normalization in the network and perform the normalization in the custom training loop instead. Save the minimum and maximum values as doubles in variables named inputMin and inputMax, respectively, and replace the input layer with an image input layer without normalization.

    inputMin = double(net.Layers(1).Min);
    inputMax = double(net.Layers(1).Max);
    layer = imageInputLayer(inputSizeNet,Normalization="none",Name="input");
    net = replaceLayer(net,"input_1",layer);
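
    For reference, here is a minimal sketch (not part of the original example) of the equivalent normalization that the custom training loop performs later via the rescale call in the extractImageFeatures helper function. The sample image is for illustration only.

    % Sketch: apply the symmetric rescale normalization manually.
    imgSample = imresize(imread("peppers.png"),inputSizeNet(1:2));  % example image shipped with MATLAB
    XNorm = rescale(single(imgSample),-1,1,InputMin=inputMin,InputMax=inputMax);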

    Initialize the network.

    net = initialize(net);

    Determine the output size of the network. Use the analyzeNetwork function to see the activation sizes of the last layer.

    analyzeNetwork(net)

    Create a variable named outputSizeNet containing the network output size.

    outputSizeNet = [8 8 2048];
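
    Alternatively, here is a minimal sketch (not in the original example) that determines the output size programmatically by passing a sample observation through the network instead of reading it from analyzeNetwork.

    % Forward a zero-valued sample and read off the spatial and channel sizes.
    XSample = dlarray(zeros(inputSizeNet,"single"),"SSCB");
    YSample = predict(net,XSample);
    outputSizeNet = size(YSample,1:3)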

    Import COCO Data Set

    Download images and annotations from the data sets "2014 Train images" and "2014 Train/val annotations," respectively, from https://cocodataset.org/#download. Extract the images and annotations into a folder named "coco". The COCO 2014 data set was collected by the Coco Consortium.
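
    As a rough sketch (the URL is an assumption based on the standard COCO download links, and the training images are a separate, much larger download), you can fetch and extract the annotation file programmatically:

    % Sketch: download and extract the 2014 annotations into the data folder.
    dataFolder = fullfile(tempdir,"coco");
    if ~exist(dataFolder,"dir")
        mkdir(dataFolder)
    end
    annotationsURL = "http://images.cocodataset.org/annotations/annotations_trainval2014.zip";  % assumed URL
    zipFilename = fullfile(dataFolder,"annotations_trainval2014.zip");
    websave(zipFilename,annotationsURL);
    unzip(zipFilename,fullfile(dataFolder,"annotations_trainval2014"))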

    Extract the captions from the file "captions_train2014.json" using the jsondecode function.

    dataFolder = fullfile(tempdir,"coco");
    filename = fullfile(dataFolder,"annotations_trainval2014","annotations","captions_train2014.json");
    str = fileread(filename);
    data = jsondecode(str)
    data = struct with fields:
               info: [1×1 struct]
             images: [82783×1 struct]
           licenses: [8×1 struct]
        annotations: [414113×1 struct]
    

    The annotations field of the struct contains the data required for image captioning.

    data.annotations
    ans=414113×1 struct array with fields:
        image_id
        caption
    

    The data set contains multiple captions for each image. To ensure that the same images do not appear in both the training and test sets, identify the unique images in the data set by applying the unique function to the IDs in the image_id field of the annotations, then view the number of unique images.

    numObservationsAll = numel(data.annotations)
    numObservationsAll = 414113
    
    imageIDs = [data.annotations.image_id];
    imageIDsUnique = unique(imageIDs);
    numUniqueImages = numel(imageIDsUnique)
    numUniqueImages = 82783
    

    Each image has at least five captions. Create a struct annotationsAll with these fields:

    • ImageID ⁠— Image ID

    • Filename ⁠— File name of the image

    • Captions ⁠— String array of raw captions

    • CaptionIDs ⁠— Vector of indices of the corresponding captions in data.annotations

    To make merging easier, sort the annotations by the image IDs.

    [~,idx] = sort([data.annotations.image_id]);
    data.annotations = data.annotations(idx);

    Loop over the annotations and merge multiple annotations when necessary.

    i = 0;
    j = 0;
    imageIDPrev = 0;
    while i < numel(data.annotations)
        i = i + 1;
        imageID = data.annotations(i).image_id;
        caption = string(data.annotations(i).caption);
        if imageID ~= imageIDPrev
            % Create new entry
            j = j + 1;
            annotationsAll(j).ImageID = imageID;
            annotationsAll(j).Filename = fullfile(dataFolder,"train2014","COCO_train2014_" + pad(string(imageID),12,"left","0") + ".jpg");
            annotationsAll(j).Captions = caption;
            annotationsAll(j).CaptionIDs = i;
        else
            % Append captions
            annotationsAll(j).Captions = [annotationsAll(j).Captions; caption];
            annotationsAll(j).CaptionIDs = [annotationsAll(j).CaptionIDs; i];
        end
        imageIDPrev = imageID;
    end

    Partition the data into training and test sets. Hold out 5% of the observations for testing.

    cvp = cvpartition(numel(annotationsAll),HoldOut=0.05);
    idxTrain = training(cvp);
    idxTest = test(cvp);
    annotationsTrain = annotationsAll(idxTrain);
    annotationsTest = annotationsAll(idxTest);

    Each element of data.annotations contains these fields:

    • id — Unique identifier for the caption

    • caption — Image caption, specified as a character vector

    • image_id — Unique identifier of the image corresponding to the caption

    To view the image and the corresponding caption, locate the image file with file name "train2014\COCO_train2014_XXXXXXXXXXXX.jpg", where "XXXXXXXXXXXX" corresponds to the image ID left-padded with zeros to a length of 12.

    imageID = annotationsTrain(1).ImageID;
    captions = annotationsTrain(1).Captions;
    filename = annotationsTrain(1).Filename;
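
    As a quick check (not part of the original example), you can rebuild the file name from the image ID and confirm that it matches the stored Filename.

    % Reconstruct the file name by left-padding the image ID with zeros to length 12.
    filenameCheck = fullfile(dataFolder,"train2014", ...
        "COCO_train2014_" + pad(string(imageID),12,"left","0") + ".jpg");
    isequal(filenameCheck,filename)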

    To view the image, use the imread and imshow functions.

    img = imread(filename);
    figure
    imshow(img)
    title(captions)
    

    Prepare Data for Training

    Prepare the captions for training and testing. Extract the text from the Captions field of the struct containing both the training and test data ( annotationsAll ), erase the punctuation, and convert the text to lowercase.

    captionsAll = cat(1,annotationsAll.Captions);
    captionsAll = erasePunctuation(captionsAll);
    captionsAll = lower(captionsAll);

    In order to generate captions, the RNN decoder requires special start and stop tokens to indicate when to start and stop generating text, respectively. Add the custom tokens "<start>" and "<stop>" to the beginnings and ends of the captions, respectively.

    captionsAll = "<start>" + captionsAll + "<stop>";

    Tokenize the captions using the tokenizedDocument function and specify the start and stop tokens using the CustomTokens option.

    documentsAll = tokenizedDocument(captionsAll,CustomTokens=["<start>" "<stop>"]);

    Create a wordEncoding object that maps words to numeric indices and back. Reduce the memory requirements by specifying a vocabulary size of 5000 corresponding to the most frequently observed words in the training data. To avoid bias, use only the documents corresponding to the training set.

    enc = wordEncoding(documentsAll(idxTrain),MaxNumWords=5000,Order="frequency");
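
    As a quick sanity check (not in the original example), you can map a word to its numeric index and back using the word2ind and ind2word functions.

    % Round-trip a common caption word through the encoding.
    idxWord = word2ind(enc,"dog");
    ind2word(enc,idxWord)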

    Create an augmented image datastore containing the images corresponding to the captions. Set the output size to match the input size of the convolutional network. To keep the images synchronized with the captions, specify a table of file names for the datastore by reconstructing the file names using the image ID. To return grayscale images as 3-channel RGB images, set the ColorPreprocessing option to "gray2rgb" .

    tblFilenames = table(cat(1,annotationsTrain.Filename));
    augimdsTrain = augmentedImageDatastore(inputSizeNet,tblFilenames,ColorPreprocessing="gray2rgb")
    augimdsTrain = 
      augmentedImageDatastore with properties:
             NumObservations: 78644
               MiniBatchSize: 1
            DataAugmentation: 'none'
          ColorPreprocessing: 'gray2rgb'
                  OutputSize: [299 299]
              OutputSizeMode: 'resize'
        DispatchInBackground: 0
    

    Initialize Model Parameters

    Initialize the model parameters. Specify 512 hidden units with a word embedding dimension of 256.

    embeddingDimension = 256;
    numHiddenUnits = 512;

    Initialize a struct containing the parameters for the encoder model.

    • Initialize the weights of the fully connected operations using the Glorot initializer, specified by the initializeGlorot function, listed at the end of the example. Specify the output size to match the embedding dimension of the decoder (256) and an input size to match the number of output channels of the pretrained network. The 'mixed10' layer of the Inception-v3 network outputs data with 2048 channels.

    numFeatures = outputSizeNet(1) * outputSizeNet(2);
    inputSizeEncoder = outputSizeNet(3);
    parametersEncoder = struct;
    % Fully connect
    parametersEncoder.fc.Weights = dlarray(initializeGlorot(embeddingDimension,inputSizeEncoder));
    parametersEncoder.fc.Bias = dlarray(zeros([embeddingDimension 1],"single"));

    Initialize a struct containing parameters for the decoder model.

    • Initialize the word embedding weights with the size given by the embedding dimension and the vocabulary size plus one, where the extra entry corresponds to the padding value.

    • Initialize the weights and biases for the Bahdanau attention mechanism with sizes corresponding to the number of hidden units of the GRU operation.

    • Initialize the weights and bias of the GRU operation.

    • Initialize the weights and biases of two fully connected operations.

    For the decoder model parameters, initialize each of the weights and biases with the Glorot initializer and zeros, respectively.

    inputSizeDecoder = enc.NumWords + 1;
    parametersDecoder = struct;
    % Word embedding
    parametersDecoder.emb.Weights = dlarray(initializeGlorot(embeddingDimension,inputSizeDecoder));
    % Attention
    parametersDecoder.attention.Weights1 = dlarray(initializeGlorot(numHiddenUnits,embeddingDimension));
    parametersDecoder.attention.Bias1 = dlarray(zeros([numHiddenUnits 1],"single"));
    parametersDecoder.attention.Weights2 = dlarray(initializeGlorot(numHiddenUnits,numHiddenUnits));
    parametersDecoder.attention.Bias2 = dlarray(zeros([numHiddenUnits 1],"single"));
    parametersDecoder.attention.WeightsV = dlarray(initializeGlorot(1,numHiddenUnits));
    parametersDecoder.attention.BiasV = dlarray(zeros(1,1,"single"));
    % GRU
    parametersDecoder.gru.InputWeights = dlarray(initializeGlorot(3*numHiddenUnits,2*embeddingDimension));
    parametersDecoder.gru.RecurrentWeights = dlarray(initializeGlorot(3*numHiddenUnits,numHiddenUnits));
    parametersDecoder.gru.Bias = dlarray(zeros(3*numHiddenUnits,1,"single"));
    % Fully connect
    parametersDecoder.fc1.Weights = dlarray(initializeGlorot(numHiddenUnits,numHiddenUnits));
    parametersDecoder.fc1.Bias = dlarray(zeros([numHiddenUnits 1],"single"));
    % Fully connect
    parametersDecoder.fc2.Weights = dlarray(initializeGlorot(enc.NumWords+1,numHiddenUnits));
    parametersDecoder.fc2.Bias = dlarray(zeros([enc.NumWords+1 1],"single"));

    Define Model Functions

    Create the functions modelEncoder and modelDecoder , listed at the end of the example, which compute the outputs of the encoder and decoder models, respectively.

    The modelEncoder function, listed in the Encoder Model Function section of the example, takes as input an array of activations X from the output of the pretrained network and passes it through a fully connected operation and a ReLU operation. Because the pretrained network does not need to be traced for automatic differentiation, extracting the features outside the encoder model function is more computationally efficient.

    The modelDecoder function, listed in the Decoder Model Function section of the example, takes as input a single input time-step corresponding to an input word, the decoder model parameters, the features from the encoder, and the network state, and returns the predictions for the next time step, the updated network state, and the attention weights.

    Specify Training Options

    Specify the options for training. Train for 30 epochs with a mini-batch size of 128 and display the training progress in a plot.

    miniBatchSize = 128;
    numEpochs = 30;
    plots = "training-progress";

    Train on a GPU if one is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox) .

    executionEnvironment = "auto";

    Check whether a GPU is available for training.

    if canUseGPU
        gpu = gpuDevice;
        disp(gpu.Name + " GPU detected and available for training.")
    end
    NVIDIA RTX A5000 GPU detected and available for training.
    

    Train Network

    Train the network using a custom training loop.

    At the beginning of each epoch, shuffle the input data. To keep the images in the augmented image datastore and the captions synchronized, create an array of shuffled indices that indexes into both data sets.

    For each mini-batch:

    • Rescale the images to the size that the pretrained network expects.

    • For each image, select a random caption.

    • Convert the captions to sequences of word indices. Specify right-padding of the sequences with the padding value corresponding to the index of the padding token.

    • Convert the data to dlarray objects. For the images, specify dimension labels "SSCB" (spatial, spatial, channel, batch).

    • For GPU training, convert the data to gpuArray objects.

    • Extract the image features using the pretrained network and reshape them to the size the encoder expects.

    • Evaluate the model loss and gradients using the dlfeval and modelLoss functions.

    • Update the encoder and decoder model parameters using the adamupdate function.

    • Display the training progress in a plot.

    Initialize the parameters for the Adam optimizer.

    trailingAvgEncoder = [];
    trailingAvgSqEncoder = [];
    trailingAvgDecoder = [];
    trailingAvgSqDecoder = [];

    Initialize the TrainingProgressMonitor object. Because the timer starts when you create the monitor object, make sure that you create the object close to the training loop.

    if plots == "training-progress"
        monitor = trainingProgressMonitor( ...
            Metrics="Loss", ...
            Info="Epoch", ...
            XLabel="Iteration");
    end

    Train the model.

    iteration = 0;
    numObservationsTrain = numel(annotationsTrain);
    numIterationsPerEpoch = floor(numObservationsTrain / miniBatchSize);
    numIterations = numIterationsPerEpoch*numEpochs;
    % Loop over epochs.
    for epoch = 1:numEpochs
        % Shuffle data.
        idxShuffle = randperm(numObservationsTrain);
        % Loop over mini-batches.
        for i = 1:numIterationsPerEpoch
            iteration = iteration + 1;
            % Determine mini-batch indices.
            idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
            idxMiniBatch = idxShuffle(idx);
            % Read mini-batch of data.
            tbl = readByIndex(augimdsTrain,idxMiniBatch);
            X = cat(4,tbl.input{:});
            annotations = annotationsTrain(idxMiniBatch);
            % For each image, select random caption.
            idx = cellfun(@(captionIDs) randsample(captionIDs,1),{annotations.CaptionIDs});
            documents = documentsAll(idx);
            % Create batch of data.
            [X,T] = createBatch(X,documents,net,inputMin,inputMax,enc,executionEnvironment);
            % Evaluate the model loss and gradients using dlfeval and the
            % modelLoss function.
            [loss,gradientsEncoder,gradientsDecoder] = dlfeval(@modelLoss,parametersEncoder, ...
                parametersDecoder,X,T);
            % Update encoder using adamupdate.
            [parametersEncoder,trailingAvgEncoder,trailingAvgSqEncoder] = adamupdate(parametersEncoder, ...
                gradientsEncoder,trailingAvgEncoder,trailingAvgSqEncoder,iteration);
            % Update decoder using adamupdate.
            [parametersDecoder,trailingAvgDecoder,trailingAvgSqDecoder] = adamupdate(parametersDecoder, ...
                gradientsDecoder,trailingAvgDecoder,trailingAvgSqDecoder,iteration);
            % Display the training progress.
            if plots == "training-progress"
                recordMetrics(monitor,iteration,Loss=loss);
                updateInfo(monitor,Epoch=epoch);
                monitor.Progress = 100 * iteration/numIterations;
            end
        end
    end

    Predict New Captions

    The caption generation process is different from the process for training. During training, at each time step, the decoder uses the true value of the previous time step as input. This is known as "teacher forcing". When making predictions on new data, the decoder uses the previous predicted values instead of the true values.

    Predicting the most likely word for each step in the sequence can lead to suboptimal results. For example, if the decoder predicts the first word of a caption is "a" when given an image of an elephant, then the next word is much less likely to be "elephant" because of the extremely low probability of the phrase "a elephant" appearing in English text.

    To address this issue, you can use the beam search algorithm: instead of taking the most likely prediction for each step in the sequence, take the top k predictions (the beam index) and for each following step, keep the top k predicted sequences so far according to the overall score.
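
    As a concrete illustration (the numbers are made up), the overall score of a candidate sequence is the sum of the log-probabilities of its words, which is how the beamSearch function below ranks candidates.

    % Hypothetical per-word probabilities for one candidate caption.
    wordProbs = [0.6 0.3 0.8];
    sequenceScore = sum(log(wordProbs))  % higher (less negative) scores rank better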

    Generate a caption of a new image by extracting the image features, inputting them into the encoder, and then using the beamSearch function, listed in the Beam Search Function section of the example.

    img = imread("dog_sitting.jpg");
    X = extractImageFeatures(net,img,inputMin,inputMax,executionEnvironment);
    beamIndex = 3;
    maxNumWords = 20;
    [words,attentionScores] = beamSearch(X,beamIndex,parametersEncoder,parametersDecoder,enc,maxNumWords);
    caption = join(words)
    caption = 
    "a small white dog standing on a lush green grass covered field"
    

    Display the image with the caption.

    figure
    imshow(img)
    title(caption)

    Predict Captions for Data Set

    To predict captions for a collection of images, loop over mini-batches of data in the datastore and extract the features from the images using the extractImageFeatures function. Then, loop over the images in the mini-batch and generate captions using the beamSearch function.

    Create an augmented image datastore and set the output size to match the input size of the convolutional network. To output grayscale images as 3-channel RGB images, set the ColorPreprocessing option to "gray2rgb" .

    tblFilenamesTest = table(cat(1,annotationsTest.Filename));
    augimdsTest = augmentedImageDatastore(inputSizeNet,tblFilenamesTest,ColorPreprocessing="gray2rgb")
    augimdsTest = 
      augmentedImageDatastore with properties:
             NumObservations: 4139
               MiniBatchSize: 1
            DataAugmentation: 'none'
          ColorPreprocessing: 'gray2rgb'
                  OutputSize: [299 299]
              OutputSizeMode: 'resize'
        DispatchInBackground: 0
    

    Generate captions for the test data. Predicting captions on a large data set can take some time. If you have Parallel Computing Toolbox™, then you can make predictions in parallel by generating captions inside a parfor loop. If you do not have Parallel Computing Toolbox, then the parfor loop runs in serial.

    beamIndex = 2;
    maxNumWords = 20;
    numObservationsTest = numel(annotationsTest);
    numIterationsTest = ceil(numObservationsTest/miniBatchSize);
    captionsTestPred = strings(1,numObservationsTest);
    documentsTestPred = tokenizedDocument(strings(1,numObservationsTest));
    for i = 1:numIterationsTest
        % Mini-batch indices.
        idxStart = (i-1)*miniBatchSize+1;
        idxEnd = min(i*miniBatchSize,numObservationsTest);
        idx = idxStart:idxEnd;
        sz = numel(idx);
        % Read images.
        tbl = readByIndex(augimdsTest,idx);
        % Extract image features.
        X = cat(4,tbl.input{:});
        X = extractImageFeatures(net,X,inputMin,inputMax,executionEnvironment);
        % Generate captions.
        captionsPredMiniBatch = strings(1,sz);
        documentsPredMiniBatch = tokenizedDocument(strings(1,sz));
        parfor j = 1:sz
            words = beamSearch(X(:,:,j),beamIndex,parametersEncoder,parametersDecoder,enc,maxNumWords);
            captionsPredMiniBatch(j) = join(words);
            documentsPredMiniBatch(j) = tokenizedDocument(words,TokenizeMethod="none");
        end
        captionsTestPred(idx) = captionsPredMiniBatch;
        documentsTestPred(idx) = documentsPredMiniBatch;
    end

    To view a test image with the corresponding caption, use the imshow function and set the title to the predicted caption.

    idx = 1;
    tbl = readByIndex(augimdsTest,idx);
    img = tbl.input{1};
    figure
    imshow(img)
    title(captionsTestPred(idx))
    

    Evaluate Model Accuracy

    To evaluate the accuracy of the captions using the BLEU score, calculate the BLEU score for each caption (the candidate) against the corresponding captions in the test set (the references) using the bleuEvaluationScore function. Using the bleuEvaluationScore function, you can compare a single candidate document to multiple reference documents.

    The bleuEvaluationScore function, by default, scores similarity using n-grams of length one through four. Because the captions are short, this behavior can lead to uninformative results, as most scores are close to zero. Set the n-gram length to one through two by setting the NgramWeights option to a two-element vector with equal weights.

    ngramWeights = [0.5 0.5];
    for i = 1:numObservationsTest
        annotation = annotationsTest(i);
        captionIDs = annotation.CaptionIDs;
        candidate = documentsTestPred(i);
        references = documentsAll(captionIDs);
        score = bleuEvaluationScore(candidate,references,NgramWeights=ngramWeights);
        scores(i) = score;
    end

    View the mean BLEU score.

    scoreMean = mean(scores)
    scoreMean = 0.3875
    

    Visualize the scores in a histogram.

    figure
    histogram(scores)
    xlabel("BLEU Score")
    ylabel("Frequency")

    Attention Function

    The attention function calculates the context vector and the attention weights using Bahdanau attention.
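
    In this notation (ours, not the example's), for encoder features f_i and decoder hidden state h, the function computes

    $$e_i = v^\top \tanh(W_1 f_i + b_1 + W_2 h + b_2) + b_v, \qquad \alpha_i = \operatorname{softmax}(e)_i, \qquad c = \sum_i \alpha_i f_i,$$

    where W_1, W_2, v and the biases correspond to weights1, weights2, weightsV, bias1, bias2, and biasV below, and c is the returned context vector.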

    function [contextVector, attentionWeights] = attention(hidden,features,weights1, ...
        bias1,weights2,bias2,weightsV,biasV)
    % Model dimensions.
    [embeddingDimension,numFeatures,miniBatchSize] = size(features);
    numHiddenUnits = size(weights1,1);
    % Fully connect.
    Y1 = reshape(features,embeddingDimension, numFeatures*miniBatchSize);
    Y1 = fullyconnect(Y1,weights1,bias1,DataFormat="CB");
    Y1 = reshape(Y1,numHiddenUnits,numFeatures,miniBatchSize);
    % Fully connect.
    Y2 = fullyconnect(hidden,weights2,bias2,DataFormat="CB");
    Y2 = reshape(Y2,numHiddenUnits,1,miniBatchSize);
    % Addition, tanh.
    scores = tanh(Y1 + Y2);
    scores = reshape(scores, numHiddenUnits, numFeatures*miniBatchSize);
    % Fully connect, softmax.
    attentionWeights = fullyconnect(scores,weightsV,biasV,DataFormat="CB");
    attentionWeights = reshape(attentionWeights,1,numFeatures,miniBatchSize);
    attentionWeights = softmax(attentionWeights,DataFormat="SCB");
    % Context.
    contextVector = attentionWeights .* features;
    contextVector = squeeze(sum(contextVector,2));
    end

    Embedding Function

    The embedding function maps an array of indices to a sequence of embedding vectors.

    function Z = embedding(X, weights)
    % Reshape inputs into a vector
    [N, T] = size(X, 1:2);
    X = reshape(X, N*T, 1);
    % Index into embedding matrix
    Z = weights(:, X);
    % Reshape outputs by separating out batch and sequence dimensions
    Z = reshape(Z, [], N, T);
    end

    Feature Extraction Function

    The extractImageFeatures function takes as input a trained dlnetwork object, an input image, statistics for image rescaling, and the execution environment, and returns a dlarray containing the features extracted from the pretrained network.

    function X = extractImageFeatures(net,X,inputMin,inputMax,executionEnvironment)
    % Resize and rescale.
    inputSize = net.Layers(1).InputSize(1:2);
    X = imresize(X,inputSize);
    X = rescale(X,-1,1,InputMin=inputMin,InputMax=inputMax);
    % Convert to dlarray.
    X = dlarray(X,"SSCB");
    % Convert to gpuArray.
    if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
        X = gpuArray(X);
    end
    % Extract features and reshape.
    X = predict(net,X);
    sz = size(X);
    numFeatures = sz(1) * sz(2);
    inputSizeEncoder = sz(3);
    miniBatchSize = sz(4);
    X = reshape(X,[numFeatures inputSizeEncoder miniBatchSize]);
    end

    Batch Creation Function

    The createBatch function takes as input a mini-batch of data, tokenized captions, a pretrained network, statistics for image rescaling, a word encoding, and the execution environment, and returns a mini-batch of data corresponding to the extracted image features and captions for training.

    function [X, T] = createBatch(X,documents,net,inputMin,inputMax,enc,executionEnvironment)
    X = extractImageFeatures(net,X,inputMin,inputMax,executionEnvironment);
    % Convert documents to sequences of word indices.
    T = doc2sequence(enc,documents,PaddingDirection="right",PaddingValue=enc.NumWords+1);
    T = cat(1,T{:});
    % Convert mini-batch of data to dlarray.
    T = dlarray(T);
    % If training on a GPU, then convert data to gpuArray.
    if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
        T = gpuArray(T);
    end
    end

    Encoder Model Function

    The modelEncoder function takes as input an array of activations X and passes it through a fully connected operation and a ReLU operation. For the fully connected operation, operate on the channel dimension only. To apply the fully connected operation across the channel dimension only, flatten the other channels into a single dimension and specify this dimension as the batch dimension using the DataFormat option of the fullyconnect function.

    function Y = modelEncoder(X,parametersEncoder)
    [numFeatures,inputSizeEncoder,miniBatchSize] = size(X);
    % Fully connect
    weights = parametersEncoder.fc.Weights;
    bias = parametersEncoder.fc.Bias;
    embeddingDimension = size(weights,1);
    X = permute(X,[2 1 3]);
    X = reshape(X,inputSizeEncoder,numFeatures*miniBatchSize);
    Y = fullyconnect(X,weights,bias,DataFormat="CB");
    Y = reshape(Y,embeddingDimension,numFeatures,miniBatchSize);
    % ReLU
    Y = relu(Y);
    end

    Decoder Model Function

    The modelDecoder function takes as input a single time-step X , the decoder model parameters, the features from the encoder, and the network state, and returns the predictions for the next time step, the updated network state, and the attention weights.

    function [Y,state,attentionWeights] = modelDecoder(X,parametersDecoder,features,state)
    hiddenState = state.gru.HiddenState;
    % Attention
    weights1 = parametersDecoder.attention.Weights1;
    bias1 = parametersDecoder.attention.Bias1;
    weights2 = parametersDecoder.attention.Weights2;
    bias2 = parametersDecoder.attention.Bias2;
    weightsV = parametersDecoder.attention.WeightsV;
    biasV = parametersDecoder.attention.BiasV;
    [contextVector, attentionWeights] = attention(hiddenState,features,weights1,bias1,weights2,bias2,weightsV,biasV);
    % Embedding
    weights = parametersDecoder.emb.Weights;
    X = embedding(X,weights);
    % Concatenate
    Y = cat(1,contextVector,X);
    % GRU
    inputWeights = parametersDecoder.gru.InputWeights;
    recurrentWeights = parametersDecoder.gru.RecurrentWeights;
    bias = parametersDecoder.gru.Bias;
    [Y, hiddenState] = gru(Y, hiddenState, inputWeights, recurrentWeights, bias, DataFormat="CBT");
    % Update state
    state.gru.HiddenState = hiddenState;
    % Fully connect
    weights = parametersDecoder.fc1.Weights;
    bias = parametersDecoder.fc1.Bias;
    Y = fullyconnect(Y,weights,bias,DataFormat="CB");
    % Fully connect
    weights = parametersDecoder.fc2.Weights;
    bias = parametersDecoder.fc2.Bias;
    Y = fullyconnect(Y,weights,bias,DataFormat="CB");
    end

    Model Loss

    The modelLoss function takes as input the encoder and decoder parameters, the encoder features X , and the target caption T , and returns the loss, the gradients of the encoder and decoder parameters with respect to the loss, and the predictions.

    function [loss,gradientsEncoder,gradientsDecoder,YPred] = ...
        modelLoss(parametersEncoder,parametersDecoder,X,T)
    miniBatchSize = size(X,3);
    sequenceLength = size(T,2) - 1;
    vocabSize = size(parametersDecoder.emb.Weights,2);
    % Model encoder
    features = modelEncoder(X,parametersEncoder);
    % Initialize state
    numHiddenUnits = size(parametersDecoder.attention.Weights1,1);
    state = struct;
    state.gru.HiddenState = dlarray(zeros([numHiddenUnits miniBatchSize],"single"));
    YPred = dlarray(zeros([vocabSize miniBatchSize sequenceLength],"like",X));
    loss = dlarray(single(0));
    padToken = vocabSize;
    for t = 1:sequenceLength
        decoderInput = T(:,t);
        YReal = T(:,t+1);
        [YPred(:,:,t),state] = modelDecoder(decoderInput,parametersDecoder,features,state);
        mask = YReal ~= padToken;
        loss = loss + sparseCrossEntropyAndSoftmax(YPred(:,:,t),YReal,mask);
    end
    % Calculate gradients
    [gradientsEncoder,gradientsDecoder] = dlgradient(loss, parametersEncoder,parametersDecoder);
    end

    Sparse Cross Entropy and Softmax Loss Function

    The sparseCrossEntropyAndSoftmax function takes as input the predictions Y, the corresponding targets T, and a sequence padding mask, applies the softmax function, and returns the masked cross-entropy loss.
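
    In our notation (not the example's), for mini-batch size N, softmax outputs y-hat, target word indices t_n, and mask values m_n, the function computes

    $$L = -\frac{1}{N}\sum_{n=1}^{N} m_n \log\!\big(\max(\hat{y}_{t_n,n},\,10^{-8})\big).$$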

    function loss = sparseCrossEntropyAndSoftmax(Y, T, mask)
    miniBatchSize = size(Y, 2);
    % Softmax.
    Y = softmax(Y,DataFormat="CB");
    % Find rows corresponding to the target words.
    idx = sub2ind(size(Y), T', 1:miniBatchSize);
    Y = Y(idx);
    % Bound away from zero.
    Y = max(Y, single(1e-8));
    % Masked loss.
    loss = log(Y) .* mask';
    loss = -sum(loss,"all") ./ miniBatchSize;
    end

    Beam Search Function

    The beamSearch function takes as input the image features X , a beam index, the parameters for the encoder and decoder networks, a word encoding, and a maximum sequence length, and returns the caption words for the image using the beam search algorithm.

    function [words,attentionScores] = beamSearch(X,beamIndex,parametersEncoder,parametersDecoder, ...
        enc,maxNumWords)
    % Model dimensions
    numFeatures = size(X,1);
    numHiddenUnits = size(parametersDecoder.attention.Weights1,1);
    % Extract features
    features = modelEncoder(X,parametersEncoder);
    % Initialize state
    state = struct;
    state.gru.HiddenState = dlarray(zeros([numHiddenUnits 1],"like",X));
    % Initialize candidates
    candidates = struct;
    candidates.State = state;
    candidates.Words = "<start>";
    candidates.Score = 0;
    candidates.AttentionScores = dlarray(zeros([numFeatures maxNumWords],"like",X));
    candidates.StopFlag = false;
    t = 0;
    % Loop over words
    while t < maxNumWords
        t = t + 1;
        candidatesNew = [];
        % Loop over candidates
        for i = 1:numel(candidates)
            % Stop generating when stop token is predicted
            if candidates(i).StopFlag
                continue
            end
            % Candidate details
            state = candidates(i).State;
            words = candidates(i).Words;
            score = candidates(i).Score;
            attentionScores = candidates(i).AttentionScores;
            % Predict next token
            decoderInput = word2ind(enc,words(end));
            [YPred,state,attentionScores(:,t)] = modelDecoder(decoderInput,parametersDecoder,features,state);
            YPred = softmax(YPred,DataFormat="CB");
            [scoresTop,idxTop] = maxk(extractdata(YPred),beamIndex);
            idxTop = gather(idxTop);
            % Loop over top predictions
            for j = 1:beamIndex
                candidate = struct;
                candidateWord = ind2word(enc,idxTop(j));
                candidateScore = scoresTop(j);
                if candidateWord == "<stop>"
                    candidate.StopFlag = true;
                    attentionScores(:,t+1:end) = [];
                else
                    candidate.StopFlag = false;
                end
                candidate.State = state;
                candidate.Words = [words candidateWord];
                candidate.Score = score + log(candidateScore);
                candidate.AttentionScores = attentionScores;
                candidatesNew = [candidatesNew candidate];
            end
        end
        % Get top candidates
        [~,idx] = maxk([candidatesNew.Score],beamIndex);
        candidates = candidatesNew(idx);
        % Stop predicting when all candidates have stop token
        if all([candidates.StopFlag])
            break
        end
    end
    % Get top candidate
    words = candidates(1).Words(2:end-1);
    attentionScores = candidates(1).AttentionScores;
    end

    Glorot Weight Initialization Function

    The initializeGlorot function generates an array of weights according to Glorot initialization.
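
    In our notation (not the example's), the weights are drawn from the uniform distribution

    $$W_{ij} \sim \mathcal{U}\!\left(-\sqrt{\frac{6}{n_{\mathrm{in}}+n_{\mathrm{out}}}},\ \sqrt{\frac{6}{n_{\mathrm{in}}+n_{\mathrm{out}}}}\right),$$

    where n_in and n_out correspond to the numIn and numOut inputs.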

    function weights = initializeGlorot(numOut, numIn)
    varWeights = sqrt( 6 / (numIn + numOut) );
    weights = varWeights * (2 * rand([numOut, numIn], "single") - 1);
    end
