
Preprocess Data for AI-Based CSI Prediction

Since R2026a

This example shows how to preprocess channel estimates and prepare a data set for training a gated recurrent unit (GRU) channel prediction network that enhances feedback on channel state information (CSI). It focuses on the Prepare Data step in the workflow for AI-Based CSI Feedback. You can run each step independently or work through the steps in order.

In this example, you preprocess channel realizations from previously generated data. For an example of the previous step in the workflow, see Generate MIMO OFDM Channel Realizations for AI-Based Systems.

Channel Realization Data

If the required data is not present in the workspace, this example generates the channel realization data by using the prepareChannelRealizations helper function.

numSamples = 500000;
if ~exist("sdsChan","var") || ~exist("channel","var") || ~exist("carrier","var") ...
        || (exist("userParams","var") && ~strcmp(userParams.Preset,"Channel prediction"))
disp("Data not present in workspace. Running data generation code.")
[sdsChan,systemParams,channel,carrier] = prepareChannelRealizations(numSamples);
end
Data not present in workspace. Running data generation code.
Starting channel realization generation
6 worker(s) running
00:00:03 - 100% Completed

After generating the data, you can view the system configuration by inspecting the outputs of the prepareChannelRealizations helper function (sdsChan, systemParams, channel, and carrier).

Create Sequence Data

Most channel prediction neural networks require the input data to be a 3-D array structured as [features, sequence, batch], where:

  • Features represent channel estimates from the transmit antennas. For each antenna, the real and imaginary parts are interleaved.

  • Sequence corresponds to a temporal sequence of present and previous features, sampled at intervals of Δt seconds, which is usually the slot time.

  • Batch indexes different channel realizations.

Each feature sequence is a [2Ntx, Nseq] array, where a single feature vector includes values across all transmit antennas (for a single subcarrier and receive antenna). A new sample is generated for every time step across subcarriers, receive antennas, and frames. However, the network operates on one subcarrier and one receive antenna at a time, both for input and prediction targets. The final input to the network is a [2Ntx, Nseq, Nsamples] array.

This figure shows features as columns (first dimension), sequences as rows (second dimension), and samples as pages. For speed and efficiency in MATLAB®, process samples in the highest dimension (pages dimension here).

Figure: 3-D stack of panels representing the input array, with the vertical axis labeled 2N_tx (interleaved IQ), the horizontal axis labeled N_seq, and the depth axis labeled N_samples.

The target data is a [2Ntx, 1] array, where the first dimension is the features, which are the transmit antenna channel gains at a given horizon (future time step). The final target data is a [2Ntx, Nsamples] array.

The signalDatastore object, sdsChan, contains numFrames files, where each file holds an Hest array with an [Nsc, Nsymbols, Nrx, Ntx] channel estimate. Each frame of the channel estimates contains Nsymbols contiguous symbols for each subcarrier, receive antenna, and transmit antenna. The input samples for the neural network contain Nsequence contiguous slots. The target samples are selected Nhorizon slots after the last input slot. As a result, generating one training sample requires Nsequence + Nhorizon slots. Use a sliding window of length Nsequence + Nhorizon slots to generate training samples. Because the window advances one symbol at a time, each subcarrier and receive antenna pair in a frame results in Nsymbols - 14 × (Nsequence + Nhorizon - 1) training samples, where 14 is the number of symbols per slot. Read the contents of the sdsChan object.
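As a quick check of the sample-count formula, this minimal sketch counts every valid window start position by brute force and compares the count with the closed form. It assumes 75 slots per frame (the value that the prepareChannelRealizations local function uses) and a 65-slot window (maxSequenceLength); the variable names are illustrative only.

```matlab
% Brute-force count of sliding-window start positions for one subcarrier and
% receive antenna. Assumed values: 75 slots per frame, 14 symbols per slot,
% 65-slot window (Nsequence + Nhorizon), window advancing one symbol per step.
nSymPerSlot  = 14;
nSlots       = 75;
nWindowSlots = 65;
nSymFrame    = nSlots*nSymPerSlot;       % 1050 symbols per frame

count = 0;
for s = 1:nSymFrame
    % The window samples symbols s, s+14, ..., s+14*(nWindowSlots-1)
    lastSym = s + nSymPerSlot*(nWindowSlots-1);
    if lastSym <= nSymFrame
        count = count + 1;
    end
end

closedForm = nSymFrame - nSymPerSlot*(nWindowSlots-1);
fprintf("Brute force: %d, closed form: %d\n", count, closedForm)
```

Both counts are 154, which matches the samplesPerSubCarrierRx value that the example computes.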

reset(sdsChan)
Hest = read(sdsChan);
[Nsc,Nsymbol,Nrx,~] = size(Hest);
maxSequenceLength = 65; % expected maximum value of Nsequence + Nhorizon in slots
symbolsPerSlot = 14;
numSlotsPerFrame = Nsymbol/symbolsPerSlot;
samplesPerSubCarrierRx = ...
(numSlotsPerFrame-maxSequenceLength+1)*symbolsPerSlot
samplesPerSubCarrierRx = 
154
if numSlotsPerFrame < maxSequenceLength
  error("Number of slots per frame (%d) is less than the expected maximum sequence length (%d).", ...
    numSlotsPerFrame,maxSequenceLength)
end

For channel prediction, each sample is a [2Ntx, Nsequence] array, where Nsequence is the number of symbols in the input sequence, sampled every Ts seconds (the slot time). Each channel estimate contains Nsc subcarriers and Nrx receive antennas. Therefore, each frame results in (Nsymbols - 14 × (Nsequence + Nhorizon - 1)) × Nsc × Nrx samples. Calculate the minimum number of frames required to generate numSamples training samples.

Nframes = ceil(numSamples/(samplesPerSubCarrierRx*Nsc*Nrx))
Nframes = 
3
if Nframes > numel(sdsChan.Files)
  error("Not enough frames to generate %d samples.", numSamples)
end
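With the values in this example (154 samples per subcarrier and receive antenna pair, Nsc = 624, and Nrx = 2), the frame count works out as:

```latex
N_\text{frames} = \left\lceil \frac{500000}{154 \times 624 \times 2} \right\rceil
               = \left\lceil \frac{500000}{192192} \right\rceil
               = \lceil 2.60 \rceil = 3
```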

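Plot the channel gain of the first subcarrier and receive antenna over one frame to check whether a window of up to maxSequenceLength slots spans enough channel variation for the network to learn the channel characteristics.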

Htemp = squeeze(Hest(1,:,1,:));
plot(real(Htemp))
hold on
plot(imag(Htemp))
hold off
grid on
xlabel("Symbols")
ylabel("Channel Gain")
legend("Tx1-I","Tx2-I","Tx1-Q","Tx2-Q")

Figure contains an axes object. The axes object with xlabel Symbols, ylabel Channel Gain contains 4 objects of type line. These objects represent Tx1-I, Tx2-I, Tx1-Q, Tx2-Q.

Preprocess Channel Realizations

Most neural networks require the input and target data to be preprocessed and reshaped. Preprocessing reduces the neural network complexity and training time. Reshaping ensures data dimensions are interpreted correctly.

  • Complex-to-Real Conversion: To feed complex-valued CSI into standard neural networks, map each complex sample to real-valued features by interleaving the in-phase and quadrature components. To match the format of the features dimension, flatten each complex sample into the feature dimension by alternating real and imaginary values, [ℜ{x1}, ℑ{x1}, ℜ{x2}, ℑ{x2}, …].

  • Add Noise: Simulate noisy channel estimates by adding white Gaussian noise.

  • Min-Max Normalization: Apply min-max normalization to improve training performance.

  • Reshape Channel Array: For autoencoder-type neural networks, the input and target data are both the 2-D subcarrier-by-transmit-antenna array. Temporal prediction neural networks require time series of 2-D arrays as inputs and 2-D arrays as targets.
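The complex-to-real conversion described above can be sketched on a hypothetical two-antenna vector x (not part of the example data). The sketch interleaves real and imaginary parts and then recovers the original complex values.

```matlab
% Toy complex channel gains, one per transmit antenna (hypothetical values).
x = [1+2i; -3+0.5i];

% Place real and imaginary parts side by side, transpose so that each column
% is [Re; Im] for one antenna, then flatten column by column. MATLAB's
% column-major order yields [Re(x1); Im(x1); Re(x2); Im(x2)].
features = reshape([real(x) imag(x)].', [], 1);

% Undo the interleaving: odd rows hold real parts, even rows imaginary parts.
xRecovered = complex(features(1:2:end), features(2:2:end));

disp(features.')
```

The same column-major rule is what the permute-and-reshape steps in this example rely on to interleave IQ values per transmit antenna.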

Separate Real and Imaginary Parts

Because the data is complex and the network requires real-valued data, separate the real and imaginary parts and store them along the sixth dimension.

dimSubcar = 1;
dimSymb = 2;
dimRxAnt = 3;
dimTxAnt = 4;
dimIQ = 6;
HestReal = cat(dimIQ,real(Hest),imag(Hest));

Get the dimensions of the data.

dimFrame = 5;
[Nsc,Nsymbol,Nrx,Ntx,Nframe,Niq] = ...
size(HestReal, ...
[dimSubcar,dimSymb,dimRxAnt,dimTxAnt,dimFrame,dimIQ])
Nsc = 
624
Nsymbol = 
1050
Nrx = 
2
Ntx = 
2
Nframe = 
1
Niq = 
2

Reshape Data

Create a transmit antenna sample array with interleaved IQ samples as the first dimension and symbols as the second dimension.

In HestReal, the symbols are time-contiguous only in the symbol dimension, which is the second dimension. Switching subcarriers, receive antennas, and frames creates discontinuities in the symbol dimension. To ensure continuity in the symbol dimension, keep the second dimension separate but combine subcarriers, receive antennas, and frames as the third dimension.

To create this array, first permute the dimensions to obtain an [Niq, Ntx, Nsymbol, Nsc, Nrx, Nframe] array. In this case, the displayed size omits the frame dimension because there is only one frame (Nframe = 1), and MATLAB drops trailing singleton dimensions.

H = permute(HestReal,[dimIQ dimTxAnt dimSymb dimSubcar dimRxAnt dimFrame]);
disp(size(H))
           2           2        1050         624           2

Reshape the array to size [Ntx × Niq, Nsymbol, Nother], where Nother = Nsc × Nrx × Nframe. Because MATLAB® stores arrays in column-major order, with the first dimension varying fastest, this operation creates an array where the first dimension contains the interleaved IQ samples for the transmit antennas, the second dimension contains the time-contiguous symbols, and the third dimension combines subcarriers, receive antennas, and frames.

Hr = reshape(H,Ntx*Niq,Nsymbol,Nsc*Nrx*Nframe);
[Ntxiq,Nsymbol,Nother] = size(Hr)
Ntxiq = 
4
Nsymbol = 
1050
Nother = 
1248
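To confirm that the permute-and-reshape step interleaves IQ values per transmit antenna, this sketch applies the same dimension order to a small random array with hypothetical sizes (toyNsc, toyNsym, toyNrx, toyNtx) and compares one column of the result against values read directly from the complex source.

```matlab
% Toy array with the same layout as Hest: [Nsc, Nsym, Nrx, Ntx] (hypothetical sizes).
toyNsc = 3; toyNsym = 5; toyNrx = 2; toyNtx = 2;
Htoy = complex(randn(toyNsc,toyNsym,toyNrx,toyNtx), randn(toyNsc,toyNsym,toyNrx,toyNtx));

HtoyReal = cat(6, real(Htoy), imag(Htoy));      % IQ on dimension 6
Ptoy = permute(HtoyReal, [6 4 2 1 3 5]);        % [Niq, Ntx, Nsym, Nsc, Nrx, Nframe]
Rtoy = reshape(Ptoy, toyNtx*2, toyNsym, []);    % [2*Ntx, Nsym, Nsc*Nrx*Nframe]

% In column-major order, the third-dimension index runs over subcarriers first,
% then receive antennas, so page = scIdx + (rxIdx-1)*toyNsc.
scIdx = 2; rxIdx = 2; symIdx = 4;
page = scIdx + (rxIdx-1)*toyNsc;
expected = [real(Htoy(scIdx,symIdx,rxIdx,1)); imag(Htoy(scIdx,symIdx,rxIdx,1)); ...
            real(Htoy(scIdx,symIdx,rxIdx,2)); imag(Htoy(scIdx,symIdx,rxIdx,2))];
disp(isequal(Rtoy(:,symIdx,page), expected))   % displays 1
```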

Plot the variation of the channel gain for the two transmit antennas for a random subcarrier, receive antenna, and frame.

figure
Hsample = Hr(:,:,randi(Nsc*Nrx*Nframe));
plot(Hsample(1:2:4,:)',Hsample(2:2:4,:)', '*-')
lim = floor(max(abs(Hsample),[],"all")*11)/10;
ylim([-lim lim])
xlim([-lim lim])
axis square
grid on
xlabel("In-Phase")
ylabel("Quadrature")
legend("Tx antenna 1","Tx antenna 2")
title("Channel Gain")

Figure contains an axes object. The axes object with title Channel Gain, xlabel In-Phase, ylabel Quadrature contains 2 objects of type line. These objects represent Tx antenna 1, Tx antenna 2.

Add Noise

To simulate noisy channel estimates, add noise to the channel data. Set the value for the signal-to-noise ratio (SNR).

SNR = 20;

Calculate the noise variance associated with the SNR and signal values.

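$$\mathrm{SNR} = 10\log_{10}\left(\frac{S}{N}\right) \;\Rightarrow\; \frac{S}{N} = 10^{\mathrm{SNR}/10} \;\Rightarrow\; N = \frac{S}{10^{\mathrm{SNR}/10}}$$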

noiseVariance = (var(Hr,[],"all")/10^(SNR/10))/2;

Generate noisy data.

Hnoisy = Hr + randn(size(Hr),"single")*sqrt(noiseVariance);

Apply Min-Max Normalization

When you work with GRUs, input data normalization is crucial. Scaling all features to a similar range helps the network learn patterns more effectively, especially in time series data where values can vary significantly in magnitude. Before choosing a scaling method, plot a 2-D histogram of the in-phase and quadrature values to check for outliers.

histogram2(Hnoisy(1:2:end,:,:),Hnoisy(2:2:end,:,:),40)
grid on
xlabel("In-Phase Data Amplitude")
ylabel("Quadrature Data Amplitude")
zlabel("Number of Occurrences")

Figure contains an axes object. The axes object with xlabel In-Phase Data Amplitude, ylabel Quadrature Data Amplitude contains an object of type histogram2.

Since the histogram does not show outliers, use min-max scaling to scale the data to a range between 0 and 1 by subtracting the minimum value and dividing by the range of the data. Check the minimum and maximum values of features.

featuresMax = max(Hnoisy,[],[2 3])
featuresMax = 4×1 single column vector

    1.9606
    1.8408
    2.0401
    1.7157

featuresMin = min(Hnoisy,[],[2 3])
featuresMin = 4×1 single column vector

    -1.6577
    -1.5097
    -1.6120
    -1.6986

Apply min-max scaling to the whole data set. This normalization ensures that all network inputs have a consistent scale and distribution, which can improve convergence speed and stability during training.

$$f_{\max} = \max\left(\Re\{H_{est}\},\ \Im\{H_{est}\}\right)$$

$$f_{\min} = \min\left(\Re\{H_{est}\},\ \Im\{H_{est}\}\right)$$

dataMax = max(featuresMax);
dataMin = min(featuresMin);

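$$\Re\{H_{est}\} \leftarrow \frac{\Re\{H_{est}\} - f_{\min}}{f_{\max} - f_{\min}}, \qquad \Im\{H_{est}\} \leftarrow \frac{\Im\{H_{est}\} - f_{\min}}{f_{\max} - f_{\min}}$$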

Hnoisys = (Hnoisy - dataMin) / (dataMax - dataMin);
Hrs = (Hr - dataMin) / (dataMax - dataMin);

Format Input Data

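The channel prediction network requires a 3-D array input, [features, sequence, batch]. The features dimension contains the interleaved IQ channel gains of the transmit antennas. The sequence dimension contains the symbols, sampled one slot apart, that form the time sequence used for prediction, and the batch dimension indexes the samples, that is, the starting symbols across all subcarriers, receive antennas, and frames.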

Sample Nseq symbols from the second dimension of the normalized noisy array, Hnoisys, which contains the time-contiguous symbols, with a period of Nts symbols, where Nts is the number of symbols per slot. This 2-D array is the 2Ntx-by-Nseq input data sample. Repeat this process for each starting symbol in the second dimension and for each subcarrier, receive antenna, and frame (third dimension). Because each sample spans Nseq symbols spaced Nts symbols apart, each subcarrier, receive antenna, and frame provides only Nsymbol - Nts(Nseq - 1) time-contiguous samples.

Select the sequence length based on the coherence time of the channel. Compute the approximate coherence time of the channel in seconds as Tc ≈ 1/(2 × Doppler spread).

Tc = 1/(2*systemParams.MaxDoppler);

Calculate the coherence time in terms of slots.

numerology = log2(systemParams.SubcarrierSpacing/15); % SCS = 15*2^numerology kHz
Tslot = 1e-3 / 2^numerology;
symbolsPerSlot = 14;
symbolTime = Tslot/symbolsPerSlot;
coherenceTimeInSlots = Tc / Tslot;
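For reference, in 5G NR the subcarrier spacing is 15 × 2^μ kHz for numerology μ, and a 14-symbol slot lasts 1/2^μ ms. This minimal sketch tabulates that mapping for a few spacings; only the 15 kHz spacing is used in this example.

```matlab
% NR numerology: SCS = 15*2^mu kHz, slot duration = 1 ms / 2^mu (normal CP).
scsList = [15 30 60 120];          % subcarrier spacings in kHz
mu = log2(scsList/15);             % numerology for each spacing
slotDuration = 1e-3 ./ 2.^mu;      % slot duration in seconds

for k = 1:numel(scsList)
    fprintf("SCS %3d kHz: numerology %d, slot %.3f ms\n", ...
        scsList(k), mu(k), slotDuration(k)*1e3)
end
```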

To capture the variations in the channel accurately, use four times the coherence time as the length of the input sequence.

sequenceLength = ceil(coherenceTimeInSlots*4)
sequenceLength = 
55
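For the channel in this example (maximum Doppler shift of 37 Hz and 15 kHz subcarrier spacing, so a 1 ms slot), the arithmetic behind sequenceLength is:

```latex
T_c \approx \frac{1}{2 \times 37\ \text{Hz}} \approx 13.51\ \text{ms}, \qquad
T_{slot} = 1\ \text{ms}, \qquad
\left\lceil \frac{4 T_c}{T_{slot}} \right\rceil = \lceil 54.05 \rceil = 55
```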

First, preallocate inputData as a 2Ntx-by-Nseq-by-((Nsymbol - Nts(Nseq - 1)) × Nsc × Nrx × Nframe) single-precision array, where Nts is the number of symbols per slot.

inputData = zeros(Ntxiq,sequenceLength, ...
  (Nsymbol-(sequenceLength-1)*symbolsPerSlot)*Nsc*Nrx*Nframe,"single");

Sample the data using nested for-loops: the outer loop (p) runs over the combined subcarrier, receive antenna, and frame index, and the inner loop (s) runs over the starting symbol. Each sequence must stay within a single subcarrier, receive antenna, and frame, because the channel samples are discontinuous across those boundaries.

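sample = 1;
for p = 1:Nsc*Nrx*Nframe            % combined subcarrier/receive antenna/frame index
  for s = 1:(Nsymbol - (sequenceLength-1)*symbolsPerSlot)   % starting symbol
    % Take sequenceLength symbols spaced one slot apart, starting at symbol s
    inputData(:,:,sample) = Hnoisys(:,s:symbolsPerSlot:s+symbolsPerSlot*(sequenceLength-1),p);
    sample = sample + 1;
  end
end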
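The nested loop can also be replaced by a single indexing operation. This sketch shows a vectorized alternative on a toy array with hypothetical sizes (4 features, 60 symbols, 3 blocks, a 14-symbol stride, and 3-slot sequences) and checks that it matches the loop. For the full data set, the gathered intermediate array is as large as inputData itself, so the loop may still be preferable when memory is tight.

```matlab
% Toy data: [features, symbols, blocks] (hypothetical sizes).
X = rand(4, 60, 3, "single");
stride = 14; seqLen = 3;
numStarts = size(X,2) - (seqLen-1)*stride;       % valid start symbols per block

% Symbol index for every (start, position) pair: numStarts-by-seqLen.
idx = (1:numStarts).' + (0:seqLen-1)*stride;

% Gather all windows at once, then split the symbol dimension into
% [seqLen, numStarts] and merge starts with blocks (starts vary fastest).
W = X(:, idx.', :);                               % [features, seqLen*numStarts, blocks]
W = reshape(W, size(X,1), seqLen, numStarts, size(X,3));
W = reshape(W, size(X,1), seqLen, []);            % [features, seqLen, numStarts*blocks]

% Reference result using the same loop structure as the example.
Wloop = zeros(size(W), "single");
n = 1;
for p = 1:size(X,3)
    for s = 1:numStarts
        Wloop(:,:,n) = X(:, s:stride:s+stride*(seqLen-1), p);
        n = n + 1;
    end
end
disp(isequal(W, Wloop))   % displays 1
```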

Standardize the data array dimensions to align with the expected format for PyTorch® networks by permuting the data to bring time steps to the first dimension. For PyTorch coexecution examples that use the preprocessed data to train models, see Further Exploration.

inputData = permute(inputData,[3,2,1]);

Check the size of the input data array.

size(inputData)
ans = 1×3

    366912    55    4

Check how much memory the input data array uses.

varInfo = whos("inputData");
fprintf("inputData is %1.0f MB in memory.\n", varInfo.bytes / 2^20)
inputData is 308 MB in memory.
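The reported size follows from the array dimensions and 4 bytes per single-precision value:

```latex
366912 \times 55 \times 4 \times 4\ \text{bytes} = 322{,}882{,}560\ \text{bytes}
\approx \frac{322{,}882{,}560}{2^{20}}\ \text{MB} \approx 307.9\ \text{MB}
```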

Format Target Data

Generate target data based on the prediction horizon. Set horizon in milliseconds; because the slot duration is 1 ms for the 15 kHz subcarrier spacing used here, this value also equals the horizon in slots. Target data contains interleaved IQ samples for the transmit antennas in the first dimension and samples in the second dimension. The targetData variable holds the normalized noiseless channel samples, Hrs, that serve as target values during training.

horizon = 2; % ms
targetData = zeros(Ntxiq,(Nsymbol-(sequenceLength-1+horizon)*symbolsPerSlot)*Nsc*Nrx*Nframe,"single");
sample = 1;
for p = 1:Nsc*Nrx*Nframe
  targetData(:,sample:sample+(Nsymbol-(sequenceLength-1+horizon)*symbolsPerSlot)-1) = ...
    Hrs(:,((sequenceLength-1)+horizon)*symbolsPerSlot+1:end,p);
  sample = sample+(Nsymbol-(sequenceLength-1+horizon)*symbolsPerSlot);
end

Permute targetData to bring time steps to the first dimension.

targetData = permute(targetData,[2,1]);
size(targetData)
ans = 1×2

    331968    4
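The loops above produce more input sequences (366912) than targets (331968), because windows that start near the end of each subcarrier, receive antenna, and frame block have no symbol horizon slots later. This minimal sketch, on a toy array with hypothetical sizes, shows one way to drop those windows so that inputs and targets line up one to one.

```matlab
% Toy sizes (hypothetical): 4 features, 40 symbols, 3 blocks,
% 14 symbols per slot, 2-slot input sequence, 1-slot prediction horizon.
nFeat = 4; nSym = 40; nBlk = 3; stride = 14; seqLen = 2; hor = 1;

% Store the symbol index as the data value so alignment is easy to verify.
X = repmat(1:nSym, nFeat, 1, nBlk);              % [features, symbols, blocks]

numIn  = nSym - (seqLen-1)*stride;               % input windows per block (26)
numTgt = nSym - (seqLen-1+hor)*stride;           % targets per block (12)

% Inputs and targets built with the same loop structure as the example.
inSeq = zeros(nFeat, seqLen, numIn*nBlk);
tgt   = zeros(nFeat, numTgt*nBlk);
n = 1;
for p = 1:nBlk
    for s = 1:numIn
        inSeq(:,:,n) = X(:, s:stride:s+stride*(seqLen-1), p);
        n = n + 1;
    end
    cols = (p-1)*numTgt + (1:numTgt);
    tgt(:,cols) = X(:, (seqLen-1+hor)*stride+1:end, p);
end

% A window starting at symbol s has a target only if s <= numTgt, so keep
% the first numTgt windows of every block and drop the rest.
keep = repmat((1:numIn).' <= numTgt, nBlk, 1);
inAligned = inSeq(:,:,keep);

% Each kept window's last symbol plus the horizon equals its target symbol.
lastSym = squeeze(inAligned(1,end,:));
disp(isequal(lastSym + hor*stride, tgt(1,:).'))   % displays 1
```

In the example's workspace, the same mask would apply along the first dimension of the permuted inputData, with numIn = Nsymbol-(sequenceLength-1)*symbolsPerSlot, numTgt = Nsymbol-(sequenceLength-1+horizon)*symbolsPerSlot, and nBlk = Nsc*Nrx*Nframe, leaving 331968 input samples to match targetData.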

You can now preprocess the data in bulk or preprocess it using a transform datastore.

Preprocess Data in Bulk

The helperPreprocess3GPPChannelData helper function preprocesses the channel realizations saved in files in the dataDir directory. The helper function takes the sdsChan signal datastore as its first input to load the channel realizations and optionally saves the preprocessed data to the processed folder in the dataDir directory.

Set TrainingObjective to "prediction" to generate preprocessed channel realizations that you can use as the input signal and the target signal of a prediction neural network. Set AverageOverSlots to false. To disable truncation in the delay domain, set TruncateChannel to false. If you have Parallel Computing Toolbox, set UseParallel to true to preprocess the channel realizations in parallel.

useParallel = false;
sdsPreprocessed = helperPreprocess3GPPChannelData( ...
sdsChan, ...
TrainingObjective="prediction", ...
AverageOverSlots=false, ...
TruncateChannel=false, ...
InputSequenceLength=sequenceLength, ...
PredictionHorizon=horizon, ...
AddNoise=true, ...
SNR=SNR, ...
DataComplexity="real (interleaved)", ...
DataDomain="Frequency-Spatial (FS)", ...
UseParallel=true, ...
SaveData=true);
Starting CSI data preprocessing
3 worker(s) running
00:00:03 - 100% Completed
data = readall(sdsPreprocessed);
inputCells = cellfun(@(C) C{1},data,UniformOutput=false);
targetCells = cellfun(@(C) C{2},data,UniformOutput=false);
inputData = cat(3, inputCells{:});
targetData = cat(2, targetCells{:});
whos("inputData","targetData")
  Name            Size                     Bytes  Class     Attributes

  inputData       4x55x71136            62599680  single              
  targetData      4x71136                1138176  single              

Apply min-max normalization.

dataMax = max(inputData,[],"all")
dataMax = single

2.0354
dataMin = min(inputData,[],"all")
dataMin = single

-2.0416
inputData = (inputData - dataMin) / (dataMax - dataMin);
targetData = (targetData - dataMin) / (dataMax - dataMin);
whos("inputData","targetData")
  Name            Size                     Bytes  Class     Attributes

  inputData       4x55x71136            62599680  single              
  targetData      4x71136                1138176  single              

Save the normalization parameters.

dataOptions.Normalization = "min-max";
dataOptions.MinValue = dataMin;
dataOptions.MaxValue = dataMax;
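To map network outputs back to channel gains, invert the min-max scaling with the saved parameters. This minimal sketch uses a struct with the same field names as dataOptions but toy values, not the example's data.

```matlab
% Toy normalization parameters with the same field names as dataOptions.
opts.MinValue = single(-2);
opts.MaxValue = single(2);

% Forward min-max scaling and its inverse.
scale   = @(x) (x - opts.MinValue) / (opts.MaxValue - opts.MinValue);
unscale = @(y) y * (opts.MaxValue - opts.MinValue) + opts.MinValue;

x = single([-2 -1 0 1.5 2]);
y = scale(x);          % values in [0, 1]
xBack = unscale(y);    % recovers the original values
disp(y)
```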

Preprocess Data Using Transform Datastore

Alternatively, preprocess the data using a transform datastore, where the transform function is the preconfigured helperPreprocess3GPPChannelData function. Normalize the data using the minimum and maximum normalization parameters determined previously. The transform function creates a new datastore, tdsChan, that applies preprocessing to the data read from the underlying datastore, sdsChan. This preprocessing method is particularly useful when dealing with large data sets that cannot fit entirely in memory. If you have Parallel Computing Toolbox, you can read and preprocess the data in parallel by setting UseParallel to true in the readall function. Because Parallel Computing Toolbox does not support nested parallel loops, set UseParallel to false when calling helperPreprocess3GPPChannelData inside a transform datastore.

tdsChan = transform(sdsChan, @(x){helperPreprocess3GPPChannelData( ...
  x, ...
  TrainingObjective="prediction", ...
  AverageOverSlots=false, ...
  TruncateChannel=false, ...
  InputSequenceLength=sequenceLength, ...
  PredictionHorizon=horizon, ...
  AddNoise=true, ...
  SNR=SNR, ...
  DataComplexity="real (interleaved)", ...
  DataDomain="Frequency-Spatial (FS)", ...
  SaveData=false, ...
  UseParallel=false, ...
  Verbose=false, ...
  Normalization="min-max", ...
  MinimumValue=dataOptions.MinValue, ...
  MaximumValue=dataOptions.MaxValue)});

Read all the data and apply preprocessing.

dataCell = readall(tdsChan,UseParallel=useParallel);
inputCells = cellfun(@(C) C{1},dataCell,UniformOutput=false);
targetCells = cellfun(@(C) C{2},dataCell,UniformOutput=false);
inputData2 = cat(3, inputCells{:});
targetData2 = cat(2, targetCells{:});
whos("inputData2","targetData2")
  Name             Size                     Bytes  Class     Attributes

  inputData2       4x55x71136            62599680  single              
  targetData2      4x71136                1138176  single              
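The same transform pattern works with any datastore. This toy sketch substitutes an arrayDatastore over hypothetical data for sdsChan and an inline min-max function for helperPreprocess3GPPChannelData, so it runs without the helper functions.

```matlab
% Hypothetical data: 5 observations, each a 4-by-10 page.
rawData = single(reshape(1:200, 4, 10, 5));            % [features, time, observations]
ads = arrayDatastore(rawData, IterationDimension=3);    % one page per read, wrapped in a cell

% Normalization parameters computed up front, as in the example.
dMin = min(rawData, [], "all");
dMax = max(rawData, [], "all");

% Transform each read: unwrap the page, scale it, and rewrap it in a cell.
tds = transform(ads, @(c) {(c{1} - dMin) / (dMax - dMin)});

scaled = readall(tds);                                  % 5-by-1 cell array of scaled pages
allScaled = cat(3, scaled{:});
disp(size(allScaled))
```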

Further Exploration

After preprocessing channel realizations, you can use the CSI snapshots to explore channel prediction neural network training in these examples:

For information about the full workflow, see AI-Based CSI Feedback.

Local Functions

function [sdsChan,systemParams,channel,carrier] = prepareChannelRealizations(numSamples)
maxSequenceLength = 65;
slotsPerFrame = 75;
symbolsPerSlot = 14;
systemParams.NumSymbols = slotsPerFrame*symbolsPerSlot;
samplesPerSubCarrierRx = (slotsPerFrame-maxSequenceLength+1)*symbolsPerSlot;
carrier = nrCarrierConfig;
nSizeGrid = 52;                                         % Number of resource blocks (RBs)
Nsc = nSizeGrid*12;
systemParams.SubcarrierSpacing = 15;
carrier.NSizeGrid = nSizeGrid;
carrier.SubcarrierSpacing = systemParams.SubcarrierSpacing;
waveInfo = nrOFDMInfo(carrier);
systemParams.TxAntennaSize = 2;
systemParams.RxAntennaSize = 2;
systemParams.MaxDoppler = 37;               % Hz
systemParams.RMSDelaySpread = 300e-9;       % s
systemParams.DelayProfile = "TDL-A";
channel = nrTDLChannel;
channel.DelayProfile = systemParams.DelayProfile;
channel.DelaySpread = systemParams.RMSDelaySpread;     % s
channel.MaximumDopplerShift = systemParams.MaxDoppler; % Hz
channel.RandomStream = "Global stream";
channel.NumTransmitAntennas = systemParams.TxAntennaSize;
channel.NumReceiveAntennas = systemParams.RxAntennaSize;
channel.ChannelFiltering = false;
channel.SampleRate = waveInfo.SampleRate;
Nrx = prod(systemParams.RxAntennaSize);
useParallel = true; % set to true if Parallel Computing Toolbox is available
saveData = true;
dataDir = fullfile(pwd,"Data");
dataFilePrefix = "nr_channel_est";
resetChannel = true;
numFrames = ceil(numSamples/(samplesPerSubCarrierRx*Nsc*Nrx));
sdsChan = helper3GPPChannelRealizations(...
  numFrames, ...
  channel, ...
  carrier, ...
  UseParallel=useParallel, ...
  SaveData=saveData, ...
  DataDir=dataDir, ...
  dataFilePrefix=dataFilePrefix, ...
  NumSlotsPerFrame=slotsPerFrame, ...
  ResetChannelPerFrame=resetChannel);
end
