メインコンテンツ

CSI フィードバック圧縮のための自己符号化器の学習

R2022b 以降

この例では、クラスター遅延線 (CDL) チャネルによってダウンリンク チャネル状態情報 (CSI) を圧縮するための自己符号化器ニューラル ネットワークの学習を行う方法を示します。

この例では、次のことを行います。

  1. CSI フィードバック自己符号化のためのニューラル ネットワーク モデルの定義と学習

  2. 前処理、符号化、復号化、および後処理を含む、完全な CSI 圧縮システムの学習済みネットワークのテスト

  3. システムの性能に対する量子化済みコードワードの影響のテスト。

CSI フィードバックのための AI ワークフロー

AI ベースの CSI フィードバック ワークフローの手順には、データの生成、データの準備、モデルの学習、およびモデルのテストが含まれます。各ステップを個別に実行することも、ステップを順番に実行することもできます。この例では "モデルの学習" に焦点を当てます。

CSI フィードバック プロセスと AI ワークフローの説明については、AI-Based CSI Feedbackを参照してください。概要として、ワークフローの手順は以下のとおりです。

1. データの生成 - Generate MIMO OFDM Channel Realizations for AI-Based Systemsの例に従って、チャネル推定データを生成します。

2. データの準備 - Preprocess Data for AI-Based CSI Feedback Compressionの例に従って、データの準備を行います。

3. モデルの学習 - モデルの学習では、この例のニューラル ネットワーク モデルの定義と学習セクションで説明されているように、前処理されたチャネル推定データをニューラル ネットワークに入力し、CSI データを再構成します。

4. モデルのテスト - モデルのテストは、Test AI-based CSI Compression Techniques for Enhanced PDSCH Throughputの例で重点的に説明されています。

自己符号化器モデルの学習、圧縮、テストを行うその他の例の一覧については、その他の調査セクションを参照してください。

ニューラル ネットワーク モデルの定義と学習

必要なデータがワークスペースに存在しない場合は、データを生成して準備します。データを前処理した後、prepareData 関数の出力 (inputDatasystemParamsdataOptionschannel、および carrier) を調べることで、システムの構成を確認できます。

if ~exist("inputData","var") || ~exist("systemParams","var") ...
        || ~exist("dataOptions","var") || ~exist("channel","var") ...
        || ~exist("carrier","var")
    numSamples = 1000;
[inputData,systemParams,dataOptions,channel,carrier] = ...
prepareData(numSamples);
end
Starting channel realization generation
6 worker(s) running
00:00:13 - 100% Completed
Starting CSI data preprocessing
6 worker(s) running
00:00:02 - 100% Completed

ニューラル ネットワーク モデルの変数の定義

ニューラル ネットワーク モデルを定義する変数を初期化します。inputData 変数には、Dmax×Ntx×2 の配列から成る Nsamples 個のサンプルが含まれています。

[maxDelay,nTx,Niq,Nsamples] = size(inputData)
maxDelay = 
28
nTx = 
8
Niq = 
2
Nsamples = 
2000
systemParams.MaxDelay = maxDelay;

データを学習セット、検証セット、テスト セットに分割します。

N = size(inputData, 4);
numTrain = floor(N*10/15)
numTrain = 
1333
numVal = floor(N*3/15)
numVal = 
400
numTest = floor(N*2/15)
numTest = 
266
inputDataT = inputData(:,:,:,1:numTrain);
inputDataV = inputData(:,:,:,numTrain+(1:numVal));
inputDataTest = inputData(:,:,:,numTrain+numVal+(1:numTest));

この例では、[1] で推奨されている自己符号化器ニューラル ネットワークの修正版を使用します。

inputSize = [maxDelay nTx 2]; % 3rd dim for real and imaginary
nLinear = prod(inputSize);
nEncoded = 64;

autoencoderNet = dlnetwork([ ...
    % Encoder
    imageInputLayer(inputSize, ...
        "Normalization","none","Name","Enc_Input")

    convolution2dLayer([3 3],2, ...
        "Padding","same","Name","Enc_Conv")
    batchNormalizationLayer("Epsilon",0.001,"Name","Enc_BN")
    leakyReluLayer(0.3,"Name","Enc_leakyRelu")

    flattenLayer("Name","Enc_flatten")

    fullyConnectedLayer(nEncoded,"Name","Enc_FC")

    sigmoidLayer("Name","Enc_Sigmoid")

    % Decoder
    fullyConnectedLayer(nLinear,"Name","Dec_FC")

    functionLayer(@(x)dlarray(reshape(x,maxDelay,nTx,2,[]),'SSCB'), ...
      "Formattable",true,"Acceleratable",true,"Name","Dec_Reshape")
    ]);

autoencoderNet = ...
  helperCSINetAddResidualLayers(autoencoderNet, "Dec_Reshape");

autoencoderNet = addLayers(autoencoderNet, ...
    [convolution2dLayer([3 3],2,"Padding","same","Name","Dec_Conv") ...
    sigmoidLayer("Name","Dec_Sigmoid")]);
autoencoderNet = ...
  connectLayers(autoencoderNet,"leakyRelu_2_3","Dec_Conv");

figure
plot(autoencoderNet)
title('CSI Compression Autoencoder')

Figure contains an axes object. The axes object with title CSI Compression Autoencoder contains an object of type graphplot.

ニューラル ネットワークの学習

自己符号化器ニューラル ネットワークの学習オプションを設定し、trainnet (Deep Learning Toolbox)関数を使用してネットワークに学習させます。学習は、Intel® Xeon® W-2133 CPU @ 3.60GHz と NVIDIA GeForce RTX 3080 GPU を使用して 13 分未満で完了します。事前学習済みのネットワークを読み込むには、trainNowfalse に設定します。保存済みのネットワークは次の設定で動作することに注意してください。これらの設定を変更する場合は、trainNowtrue に設定します。

txAntennaSize = [2 2 2 1 1]; % rows, columns, polarizations, panels
rxAntennaSize = [2 1 1 1 1]; % rows, columns, polarizations, panels
rmsDelaySpread = 300e-9;     % s
maxDoppler = 5;              % Hz
nSizeGrid = 52;              % Number resource blocks (RB)
                             % 12 subcarriers per RB
subcarrierSpacing = 15; 
trainNow = false;

miniBatchSize = 1000;
trainOptions = trainingOptions("adam", ...
InitialLearnRate=0.01, ...
LearnRateSchedule="piecewise", ...
LearnRateDropPeriod=138, ...
LearnRateDropFactor=0.7456, ...
Epsilon=1e-7, ...
MaxEpochs=1000, ...
MiniBatchSize=miniBatchSize, ...
Shuffle="every-epoch", ...
ValidationData={inputDataV,inputDataV}, ...
ValidationFrequency=20, ...
ValidationPatience=20, ...
Metrics="rmse", ...
Verbose=true, ...
OutputNetwork="best-validation-loss", ...
ExecutionEnvironment="auto", ...
Plots='none')
trainOptions = 
  TrainingOptionsADAM with properties:

             GradientDecayFactor: 0.9000
                       MaxEpochs: 1000
                InitialLearnRate: 0.0100
               LearnRateSchedule: 'piecewise'
             LearnRateDropFactor: 0.7456
             LearnRateDropPeriod: 138
                   MiniBatchSize: 1000
                         Shuffle: 'every-epoch'
         CheckpointFrequencyUnit: 'epoch'
        PreprocessingEnvironment: 'serial'
                         Verbose: 1
                VerboseFrequency: 50
                  ValidationData: {[28×8×2×400 single]  [28×8×2×400 single]}
             ValidationFrequency: 20
              ValidationPatience: 20
                         Metrics: 'rmse'
             ObjectiveMetricName: 'loss'
            ExecutionEnvironment: 'auto'
                           Plots: 'none'
                       OutputFcn: []
                  SequenceLength: 'longest'
            SequencePaddingValue: 0
        SequencePaddingDirection: 'right'
                InputDataFormats: "auto"
               TargetDataFormats: "auto"
         ResetInputNormalization: 1
       ResetInverseNormalization: 1
                NormalizeTargets: 0
    BatchNormalizationStatistics: 'auto'
                   OutputNetwork: 'best-validation-loss'
                    Acceleration: "auto"
                  CheckpointPath: ''
             CheckpointFrequency: 1
        CategoricalInputEncoding: 'integer'
       CategoricalTargetEncoding: 'auto'
                L2Regularization: 1.0000e-04
         GradientThresholdMethod: 'l2norm'
               GradientThreshold: Inf
      SquaredGradientDecayFactor: 0.9990
                         Epsilon: 1.0000e-07

lossFunc = @(x,t) nmseLossdB(x,t);

ネットワークの入力と出力との間の dB 単位の正規化平均二乗誤差 (NMSE) を学習損失関数として使用し、自己符号化器に最適な重みのセットを見つけます。

if trainNow
  [net,trainInfo] = ...
    trainnet(inputDataT,inputDataT,autoencoderNet,lossFunc,trainOptions); %#ok<UNRCH>
  save("csiTrainedNetwork_" ...
    + string(datetime("now","Format","dd_MM_HH_mm")), ...
    'net','trainInfo','systemParams','dataOptions','trainOptions')
else
  systemParamsCached = systemParams;
load("csiTrainedNetwork202507",'net','trainInfo','systemParams','trainOptions')
  if ~checkSystemCompatibility(systemParams,systemParamsCached)
    error("CSIExample:Missmatch", ...
      "Saved network does not match settings. Set trainNow to true.")
  end
end

学習済みネットワークのテスト

predict (Deep Learning Toolbox)関数を使用してテスト データを処理します。

Hhat = predict(net,inputDataTest);

自己符号化器ネットワークの入力と出力との間のコサイン類似度と NMSE を計算します。コサイン類似度は次のように定義されます。

s=hˆmH.hn(hˆmH.hˆm)(hnH.hn)

ここで、hn は自己符号化器の入力におけるチャネル推定で、hˆn は自己符号化器の出力におけるチャネル推定です。コサイン類似度の詳細については、「Cosine Similarity As a Channel Estimate Quality Metric」の例を参照してください。NMSE は次のように定義されます。

NMSE=E{H-Hˆ22H22}normalized mean square error is equal to the square of the second norm of the difference between autoencoder input and output, divided y the square of the seconf norm of the autoencoder input.

ここで、H は自己符号化器の入力におけるチャネル推定で、Hˆ は自己符号化器の出力におけるチャネル推定です。

cossim = zeros(numTest,1);
nmse = zeros(numTest,1);
for n=1:numTest
    in = inputDataTest(:,:,1,n) + 1i*(inputDataTest(:,:,2,n));
    out = Hhat(:,:,1,n) + 1i*(Hhat(:,:,2,n));

    % Calculate correlation
    cossim(n) = helperComplexCosineSimilarity(in,out);

    % Calculate NMSE
    mse = mean(abs(in-out).^2,'all');
    nmse(n) = 10*log10(mse / mean(abs(in).^2,'all'));
end
figure
tiledlayout(3,1)
nexttile
histogram(abs(cossim),"Normalization","probability")
grid on
title(sprintf("Cosine Similarity Magnitude (Mean = %1.2f)", ...
mean(abs(cossim),'all')))
xlabel("Cosine Similarity Magnitude"); ylabel("PDF")
nexttile
histogram(angle(cossim),"Normalization","probability")
grid on
title(sprintf("Cosine Similarity Angle (Mean = %1.2f)", ...
mean(angle(cossim),'all')))
xlabel("Cosine Similarity Angle"); ylabel("PDF")
nexttile
histogram(nmse,"Normalization","probability")
grid on
title(sprintf("NMSE (Mean NMSE = %1.2f dB)", ...
mean(nmse,'all')))
xlabel("NMSE (dB)"); ylabel("PDF")

Figure contains 3 axes objects. Axes object 1 with title Cosine Similarity Magnitude (Mean = 1.00), xlabel Cosine Similarity Magnitude, ylabel PDF contains an object of type histogram. Axes object 2 with title Cosine Similarity Angle (Mean = 0.00), xlabel Cosine Similarity Angle, ylabel PDF contains an object of type histogram. Axes object 3 with title NMSE (Mean NMSE = -44.57 dB), xlabel NMSE (dB), ylabel PDF contains an object of type histogram.

完全な CSI フィードバック システム

次の図は、CSI フィードバックのチャネル推定に関する全体の処理を示しています。UE によって、CSI-RS 信号を使用して 1 つのスロットのチャネル応答 Hest が推定されます。自己符号化器の符号化器部分を使用して、前処理されたチャネル推定 Htr が符号化され、1 行 Nenc 列の圧縮された配列が生成されます。自己符号化器の復号化器部分によって、圧縮された配列が圧縮解除され、Htrˆ が取得されます。Htrˆ を後処理することで、Hestˆ が生成されます。

End-to-end CSI compression

符号化された配列を取得するには、自己符号化器を符号化器ネットワークと復号化器ネットワークの 2 つの部分に分割します。

[encNet,decNet] = helperCSINetSplitEncoderDecoder(net,"Enc_Sigmoid");
plotNetwork(net,encNet,decNet)

Figure contains 3 axes objects. Axes object 1 with title Autoencoder contains an object of type graphplot. Axes object 2 with title Encoder contains an object of type graphplot. Axes object 3 with title Decoder contains an object of type graphplot.

チャネル推定を生成します。

numFrames = 100;
nRx = prod(systemParams.RxAntennaSize);

Hest = helper3GPPChannelRealizations(...
  numFrames, ...
  channel, ...
  carrier, ...
  UseParallel           = false, ...
  SaveData              = false, ...
  Verbose               = false, ...
  ResetChannelPerFrame  = true, ...
  NumSlotsPerFrame      = 1);

チャネル推定を符号化および復号化します。

codeword = helperCSINetEncode(encNet,Hest,systemParams);
Hhat = helperCSINetDecode(decNet,codeword,systemParams);

完全な CSI フィードバック システムについて、コサイン類似度と NMSE を計算します。

H = squeeze(mean(Hest,2));
nmseE2E = zeros(nRx,numFrames);
cossimE2E = zeros(nRx,numFrames);
for rx=1:nRx
    for n=1:numFrames
        out = Hhat(:,rx,:,n);
        in = H(:,rx,:,n);
        cossimE2E(rx,n) = mean(helperComplexCosineSimilarity(in,out));
        nmseE2E(rx,n) = helperNMSE(in,out);
    end
end
figure
tiledlayout(3,1)
nexttile
histogram(abs(cossimE2E),"Normalization","probability")
grid on
title(sprintf("Complete Cosine Similarity Magnitude (Mean = %1.2f)", ...
mean(abs(cossimE2E),'all')))
xlabel("Cosine Similarity Magnitude"); ylabel("PDF")
nexttile
histogram(angle(cossimE2E),"Normalization","probability")
grid on
title(sprintf("Complete Cosine Similarity Angle (Mean = %1.2f)", ...
mean(angle(cossimE2E),'all')))
xlabel("Cosine Similarity Angle"); ylabel("PDF")
nexttile
histogram(nmseE2E,"Normalization","probability")
grid on
title(sprintf("Complete NMSE (Mean NMSE = %1.2f dB)", ...
mean(nmseE2E,'all')))
xlabel("NMSE (dB)"); ylabel("PDF")

Figure contains 3 axes objects. Axes object 1 with title Complete Cosine Similarity Magnitude (Mean = 0.97), xlabel Cosine Similarity Magnitude, ylabel PDF contains an object of type histogram. Axes object 2 with title Complete Cosine Similarity Angle (Mean = 0.00), xlabel Cosine Similarity Angle, ylabel PDF contains an object of type histogram. Axes object 3 with title Complete NMSE (Mean NMSE = -16.19 dB), xlabel NMSE (dB), ylabel PDF contains an object of type histogram.

量子化されたコードワードの効果

実際のシステムでは、符号化されたコードワードを少ないビット数で量子化しなければなりません。[2,10] ビットの範囲で量子化の影響をシミュレーションします。結果は、6 ビットで単精度のパフォーマンスを近似するのに十分であることを示しています。

CSI compression with autoencoder and quantization

maxVal = 1;
minVal = -1;
idxBits = 1;
nBitsVec = 2:10;
rhoQ = zeros(nRx,numFrames,length(nBitsVec));
nmseQ = zeros(nRx,numFrames,length(nBitsVec));
for numBits = nBitsVec
disp("Running for " + numBits + " bit quantization")

    % Quantize between 0:2^n-1 to get bits
    qCodeword = uencode(double(codeword*2-1), numBits);

    % Get back the floating point, quantized numbers
    codewordRx = (single(udecode(qCodeword,numBits))+1)/2;
    Hhat = helperCSINetDecode(decNet,codewordRx,systemParams);
    H = squeeze(mean(Hest,2));
    for rx=1:nRx
        for n=1:numFrames
            out = Hhat(:,rx,:,n);
            in = H(:,rx,:,n);
            rhoQ(rx,n,idxBits) = helperCSINetCorrelation(in,out);
            nmseQ(rx,n,idxBits) = helperNMSE(in,out);
        end
    end
    idxBits = idxBits + 1;
end
Running for 2 bit quantization
Running for 3 bit quantization
Running for 4 bit quantization
Running for 5 bit quantization
Running for 6 bit quantization
Running for 7 bit quantization
Running for 8 bit quantization
Running for 9 bit quantization
Running for 10 bit quantization
figure
tiledlayout(2,1)
nexttile
plot(nBitsVec,squeeze(mean(rhoQ,[1 2])),'*-')
title("Correlation (Codeword-" + size(codeword,3) + ")")
xlabel("Number of Quantization Bits"); ylabel("\rho")
grid on
nexttile
plot(nBitsVec,squeeze(mean(nmseQ,[1 2])),'*-')
title("NMSE (Codeword-" + size(codeword,3) + ")")
xlabel("Number of Quantization Bits"); ylabel("NMSE (dB)")
grid on

Figure contains 2 axes objects. Axes object 1 with title Correlation (Codeword-64), xlabel Number of Quantization Bits, ylabel \rho contains an object of type line. Axes object 2 with title NMSE (Codeword-64), xlabel Number of Quantization Bits, ylabel NMSE (dB) contains an object of type line.

その他の調査

自己符号化器は、[624 8] の単精度複素チャネル推定配列を、平均相関係数が 0.99 で NMSE が -19.55 dB である [64 1] 単精度配列に圧縮します。6 ビットの量子化を使用すると、必要な CSI フィードバック データは 384 ビットのみとなり、圧縮率は約 800:1 になります。

display("Compression ratio is " + (624*8*32*2)/(64*6) + ":" + 1)
    "Compression ratio is 832:1"

truncationFactor がシステムのパフォーマンスに与える影響を調査します。5G システムのパラメーター、チャネルのパラメーター、符号化されたシンボルの数を変更し、定義されたチャネルに最適な値を見つけます。

チャネル状態情報フィードバックを使用した NR PDSCH のスループットの例では、チャネル状態情報 (CSI) フィードバックを使用して物理ダウンリンク共有チャネル (PDSCH) のパラメーターを調整し、スループットを測定する方法を示しています。CSI フィードバック アルゴリズムを CSI 圧縮自己符号化器に置き換えて、パフォーマンスを比較します。

この例では、CSI の圧縮および再構成のための自己符号化器の設計、学習、および評価を行う方法を示します。タスクに固有のその他のプロセスについては、以下の例を参照してください。

MATLAB 上で動作する PyTorch ベースおよび Keras ベースのニューラル ネットワークの学習とテストの方法についても学ぶことができます。

補助関数

補助関数を調べて、システムの詳細な実装を確認します。

学習データの生成

helper3GPPChannelRealizations

ネットワークの定義と操作

helperCSINetDLNetwork

helperCSINetAddResidualLayers

helperCSINetSplitEncoderDecoder

CSI の処理

helperPreprocess3GPPChannelData

helperPostprocess3GPPChannelData

helperCSINetEncode

helperCSINetDecode

パフォーマンスの測定

helperComplexCosineSimilarity

helperNMSE

付録: 実験マネージャーを使ったハイパーパラメーターの最適化

最適なパラメーターを見つけるには、実験マネージャー アプリを使用します。CSITrainingProject.mlproj は事前構成されたプロジェクトです。プロジェクトを抽出します。

projectName = "CSITrainingProject";
if ~exist(projectName,"dir")
  projRoot = helperCSINetExtractProject(projectName);
else
  projRoot = fullfile(exRoot(),projectName);
end

プロジェクトを開くには、実験マネージャー アプリを起動し、次のファイルを開きます。

disp(fullfile(".","CSITrainingProject","CSITrainingProject.prj"))
.\CSITrainingProject\CSITrainingProject.prj

ハイパーパラメーター最適化時に使用する入力データと自己符号化器オプションを保存します。

dataDir = fullfile(pwd,"Data","processed");
if ~isfolder(dataDir)
  mkdir(dataDir)
end
save(fullfile(dataDir,"nr_channel_preprocessed.mat"), ...
    "inputData","systemParams")
save('data_folder.mat', "dataDir");

ハイパーパラメーターの最適化実験では、次の図のように、ハイパーパラメーターの検索範囲が指定されたベイズ最適化が使用されます。プロジェクトを開いたら、実験設定関数 CSIAutoEncNN_setup を使用できます。カスタム メトリクス関数は E2E_NMSE です。

ExperimentSetup.png

最適なパラメーターは、初期学習率が 0.01、学習率低下期間が 156 回の反復、学習率低下係数が 0.5916 です。最適なハイパーパラメーターを見つけたら、同じパラメーターを使ってネットワークを複数回学習させ、学習結果が最も良いネットワークを見つけます。

ExperimentSetp2.png

9 回目の試行で最適な E2E_NMSE が得られました。この例では、この学習済みのネットワークを保存済みのネットワークとして使用します。

ExperimentResults2.png

バッチ モードの構成

実行の [モード]Batch Sequential または Batch Simultaneous に設定されている場合、学習データは、[Prepare Data in Bulk] セクションの dataDir 変数によって定義された場所にあるワーカーにアクセスできなければなりません。dataDir を、ワーカーがアクセスできるネットワークの場所に設定してください。詳細については、Offload Experiments as Batch Jobs to a Cluster (Deep Learning Toolbox)を参照してください。

ローカル関数

function [inputData,systemParams,dataOptions,channel,carrier] = prepareData(numSamples)
carrier = nrCarrierConfig;
nSizeGrid = 52;                                         % Number resource blocks (RB)
systemParams.SubcarrierSpacing = 15;  % 15, 30, 60, 120 kHz
carrier.NSizeGrid = nSizeGrid;
carrier.SubcarrierSpacing = systemParams.SubcarrierSpacing;
waveInfo = nrOFDMInfo(carrier);
systemParams.TxAntennaSize = [2 2 2 1 1];   % rows, columns, polarization, panels
systemParams.RxAntennaSize = [2 1 1 1 1];   % rows, columns, polarization, panels
systemParams.MaxDoppler = 5;                % Hz
systemParams.RMSDelaySpread = 300e-9;       % s
systemParams.DelayProfile = "CDL-C"; % CDL-A, CDL-B, CDL-C, CDL-D, CDL-D, CDL-E
systemParams.NumSubcarriers = carrier.NSizeGrid*12;
channel = nrCDLChannel;
channel.DelayProfile = systemParams.DelayProfile;
channel.DelaySpread = systemParams.RMSDelaySpread;     % s
channel.MaximumDopplerShift = systemParams.MaxDoppler; % Hz
channel.RandomStream = "Global stream";
channel.TransmitAntennaArray.Size = systemParams.TxAntennaSize;
channel.ReceiveAntennaArray.Size = systemParams.RxAntennaSize;
channel.ChannelFiltering = false;
channel.SampleRate = waveInfo.SampleRate;
samplesPerSlot = ...
  sum(waveInfo.SymbolLengths(1:waveInfo.SymbolsPerSlot));
channel.NumTimeSamples = samplesPerSlot; % 1 slot worth of samples
systemParams.NumSymbols = 14;
useParallel = true;
saveData =  true;
dataDir = fullfile(pwd,"Data");
dataFilePrefix = "nr_channel_est";
numSlotsPerFrame = 1;
resetChannel = true;
sdsChan = helper3GPPChannelRealizations(...
  numSamples, ...
  channel, ...
  carrier, ...
  UseParallel=useParallel, ...
  SaveData=saveData, ...
  DataDir=dataDir, ...
  dataFilePrefix=dataFilePrefix, ...
  NumSlotsPerFrame=numSlotsPerFrame, ...
  ResetChannelPerFrame=resetChannel);

dataOptions.DataDomain = "Frequency-Spatial (FS)";
dataOptions.TruncationFactor = 10;
Tdelay = 1/(systemParams.NumSubcarriers*carrier.SubcarrierSpacing*1e3);
rmsDelaySpreadSamples = channel.DelaySpread/Tdelay;
[data,dataOptions] = helperPreprocess3GPPChannelData( ...
  sdsChan, ...
  TrainingObjective          = "autoencoding", ...
  AverageOverSlots           = true, ...
  TruncateChannel            = true, ...
  ExpectedDelaySpreadSamples = rmsDelaySpreadSamples, ...
  TruncationFactor           = dataOptions.TruncationFactor, ...
  DataComplexity             = "real (2D)", ...
  IQDimension                = 3, ...
  DataDomain                 = dataOptions.DataDomain, ...
  UseParallel                = useParallel, ...
  SaveData                   = false);
meanVal = mean(data{1},'all');
stdVal = std(data{1},[],'all');
inputData = (data{1}-meanVal) / stdVal;
targetStd = 0.0212;
inputData = inputData*targetStd+0.5;
systemParams.Normalization = "mean-variance";
systemParams.MeanValue = meanVal;
systemParams.StandardDeviationValue = stdVal;
systemParams.TargetStandardDeviation = targetStd;
systemParams.ExpectedDelaySpreadSamples = dataOptions.ExpectedDelaySpreadSamples;
end

function compatible = checkSystemCompatibility(systemParams,systemParamsCached)
compatible = false;
if systemParams.SubcarrierSpacing ~= systemParamsCached.SubcarrierSpacing
  return
end
if any(systemParams.TxAntennaSize ~= systemParamsCached.TxAntennaSize)
  return
end
if any(systemParams.RxAntennaSize ~= systemParamsCached.RxAntennaSize)
  return
end
if systemParams.MaxDoppler ~= systemParamsCached.MaxDoppler
  return
end
if systemParams.RMSDelaySpread ~= systemParamsCached.RMSDelaySpread
  return
end
if systemParams.DelayProfile ~= systemParamsCached.DelayProfile
  return
end
if systemParams.NumSubcarriers ~= systemParamsCached.NumSubcarriers
  return
end
if systemParams.NumSymbols ~= systemParamsCached.NumSymbols
  return
end
if abs(systemParams.MeanValue - systemParamsCached.MeanValue) > 3e-2
  return
end
if (systemParams.StandardDeviationValue - systemParamsCached.StandardDeviationValue) > 3e-2
  return
end
if systemParams.TargetStandardDeviation ~= systemParamsCached.TargetStandardDeviation
  return
end
if systemParams.ExpectedDelaySpreadSamples ~= systemParamsCached.ExpectedDelaySpreadSamples
  return
end
if systemParams.MaxDelay ~= systemParamsCached.MaxDelay
  return
end
compatible = true;
end

function plotNetwork(net,encNet,decNet)
%plotNetwork Plot autoencoder network
%   plotNetwork(NET,ENC,DEC) plots the full autoencoder network together
%   with encoder and decoder networks.
fig = figure;
t1 = tiledlayout(1,2,'TileSpacing','Compact');
t2 = tiledlayout(t1,1,1,'TileSpacing','Tight');
t3 = tiledlayout(t1,2,1,'TileSpacing','Tight');
t3.Layout.Tile = 2;
nexttile(t2)
plot(net)
title("Autoencoder")
nexttile(t3)
plot(encNet)
title("Encoder")
nexttile(t3)
plot(decNet)
title("Decoder")
pos = fig.Position;
pos(3) = pos(3) + 200;
pos(4) = pos(4) + 300;
pos(2) = pos(2) - 300;
fig.Position = pos;
end

function rootDir = exRoot()
%exRoot Example root directory
rootDir = fileparts(which("helperCSINetDLNetwork"));
end

function loss = nmseLossdB(x,xHat)
%nmseLossdB NMSE loss in dB
in = complex(x(:,:,1,:),x(:,:,2,:));
out = complex(xHat(:,:,1,:),xHat(:,:,2,:));
nmsePerObservation = helperNMSE(in,out);
loss = mean(nmsePerObservation);
end

参考文献

[1] Wen, Chao-Kai, Wan-Ting Shih, and Shi Jin. “Deep Learning for Massive MIMO CSI Feedback.” IEEE Wireless Communications Letters 7, no. 5 (October 2018): 748–51. https://doi.org/10.1109/LWC.2018.2818160.

参考

トピック