メインコンテンツ

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

自己符号化器を使った CSI フィードバック

この例では、自己符号化器ニューラル ネットワークを使用し、クラスター遅延線 (CDL) チャネルによってダウンリンク チャネル状態情報 (CSI) を圧縮する方法を示します。CSI フィードバックは、生のチャネル推定配列の形式になります。

はじめに

従来の 5G 無線ネットワークにおける CSI パラメーターは、チャネル推定配列から抽出されるチャネルの状態に関連する量です。CSI フィードバックには、チャネル品質指標 (CQI)、さまざまなコードブック セットをもつプリコーディング行列インデックス (PMI)、ランク インジケーター (RI) などのいくつかのパラメーターが含まれます。UE は、CSI 基準信号 (CSI-RS) を使用して CSI パラメーターを測定および計算します。ユーザー端末 (UE) は、CSI パラメーターをアクセス ネットワーク ノード (gNB) にフィードバックとして報告します。gNB は、CSI パラメーターを受信すると、変調スキーム、符号化率、トランスミッション レイヤーの数、MIMO プリコーディングなどの属性を使ってダウンリンク データ伝送をスケジュールします。次の図は、CSI-RS の送信、CSI フィードバック、および CSI パラメーターに基づいてスケジュールされたダウンリンク データ伝送の概要を示しています。

UE は、チャネル推定を処理して CSI フィードバック データの量を削減します。UE は、代わりの手法として、チャネル推定配列を圧縮してフィードバックします。受信後、gNB はチャネル推定値を解凍して処理し、ダウンリンクのデータ リンク パラメーターを決定します。圧縮と解凍は、自己符号化器ニューラル ネットワーク [12] を使用して実現できます。この手法を使用すると、従来の量子化コードブックを使用する必要がなくなり、システム全体のパフォーマンスが向上します。

次の例では、これらのシステム パラメーターをもつ 5G ダウンリンク チャネルを使用します。

txAntennaSize = [2 2 2 1 1]; % rows, columns, polarizations, panels
rxAntennaSize = [2 1 1 1 1]; % rows, columns, polarizations, panels
rmsDelaySpread = 300e-9;     % s
maxDoppler = 5;              % Hz
nSizeGrid = 52;              % Number resource blocks (RB)
                             % 12 subcarriers per RB
subcarrierSpacing = 15;      % 15, 30, 60, 120 kHz
numTrainingChEst = 15000;

% Carrier definition
carrier = nrCarrierConfig;
carrier.NSizeGrid = nSizeGrid;
carrier.SubcarrierSpacing = subcarrierSpacing
carrier = 
  nrCarrierConfig with properties:

                NCellID: 1
      SubcarrierSpacing: 15
           CyclicPrefix: 'normal'
              NSizeGrid: 52
             NStartGrid: 0
                  NSlot: 0
                 NFrame: 0
    IntraCellGuardBands: [0×2 double]

   Read-only properties:
         SymbolsPerSlot: 14
       SlotsPerSubframe: 1
          SlotsPerFrame: 10

autoEncOpt.NumSubcarriers = carrier.NSizeGrid*12;
autoEncOpt.NumSymbols = carrier.SymbolsPerSlot;
autoEncOpt.NumTxAntennas = prod(txAntennaSize);
autoEncOpt.NumRxAntennas = prod(rxAntennaSize);

データの生成と前処理

AI ベースのシステムを設計する最初のステップは、学習データとテスト データを準備することです。この例では、シミュレートされたチャネル推定を生成し、データを前処理します。5G Toolbox™ の関数を使用して、CDL-C チャネルを構成します。データ生成の詳細については、Prepare Data for CSI Processingの例を参照してください。CDL-C チャネルを定義します。

channel = nrCDLChannel;
channel.DelayProfile = 'CDL-C';
channel.DelaySpread = rmsDelaySpread;       % s
channel.MaximumDopplerShift = maxDoppler;   % Hz
channel.RandomStream = "Global stream";
channel.TransmitAntennaArray.Size = txAntennaSize;
channel.ReceiveAntennaArray.Size = rxAntennaSize;
channel.ChannelFiltering = false;           % No filtering for 

データセットに対して生成されるサンプルの数は、以下のように設定できます。実行時間を短縮するには、サンプル数を 1500 に設定します。保存済みの結果では 15000 個のサンプルが使用されています。

numSamples = 1500;

前処理済みデータの準備、打ち切り係数、およびタイミング オフセットのドメインを選択します。

autoEncOpt.DataDomain       = "Frequency-Spatial";
autoEncOpt.TruncationFactor = 10;
autoEncOpt.ZeroTimingOffset = true;

Parallel Computing Toolbox™ を利用できる場合は、autoEncOpt.UseParallel 変数を true に設定して並列データ生成を有効にします。Intel® Xeon® W-2133 CPU @ 3.60GHz を搭載した PC を使用し、6 つのワーカーで並列実行する場合、15000 サンプルのデータ生成に約 6 分かかります。

autoEncOpt.UseParallel = true;

autoEncOpt.SaveData を有効にし、前処理されたチャネル推定を .mat ファイルに保存します。

autoEncOpt.SaveData = true;
autoEncOpt.DataDir = "Data";
autoEncOpt.DataFilePrefix = "CH_est";

サンプルの生成

helperCSINetGenerateData 補助関数は、Prepare Data for CSI Processingの例で説明されているプロセスを使用して、'numSamples' 個の前処理済みの [Ndelay Ntx 2] チャネル推定を生成します。saveOptions.SaveData を有効にした場合、この関数は、各 [Ndelay Ntx 2 Nrx] チャネル推定を saveOptions.DataFilePrefix. というプレフィックスをもつ個別のファイルとして saveOptions.DataDir に保存します。

[HtruncReal,autoEncOpt] = helperCSINetGenerateData(numSamples,channel,carrier,autoEncOpt);
Starting CSI data generation
6 worker(s) running
00:00:30 -  0% Completed
00:00:37 -  0% Completed
00:00:38 -  0% Completed
00:00:50 -  0% Completed
00:00:50 -  0% Completed
00:00:53 -  0% Completed
00:00:58 -  0% Completed
00:01:05 - 100% Completed

サンプルの前処理

HtruncReal 変数には Nframes 個のフレームが含まれています。各フレームには、独立した Nrx 個の受信アンテナのデータが含まれています。

[maxDelay,nTx,Niq,nRx,Nframes] = size(HtruncReal)
maxDelay = 
28
nTx = 
8
Niq = 
2
nRx = 
2
Nframes = 
750

フレームとアンテナを組み合わせます。次に、平均値と標準偏差を計算し、平均と標準偏差の値を使用してデータを正規化します。

HtruncReal = reshape(HtruncReal,maxDelay,nTx,Niq,nRx*Nframes);
meanVal = mean(HtruncReal,'all')
meanVal = single

-2.5427e-04
stdVal = std(HtruncReal,[],'all')
stdVal = single

16.1309

データを学習セット、検証セット、テスト セットに分割します。また、データを正規化し、ゼロ平均と目標標準偏差 0.0212 を実現します。これにより、データの大部分が [-0.5 0.5] の範囲に制限されます。

N = size(HtruncReal, 4);
numTrain = floor(N*10/15)
numTrain = 
1000
numVal = floor(N*3/15)
numVal = 
300
numTest = floor(N*2/15)
numTest = 
200
targetStd = 0.0212;
HTReal = (HtruncReal(:,:,:,1:numTrain)-meanVal) ...
  /stdVal*targetStd+0.5;
HVReal = (HtruncReal(:,:,:,numTrain+(1:numVal))-meanVal) ...
  /stdVal*targetStd+0.5;
HTestReal = (HtruncReal(:,:,:,numTrain+numVal+(1:numTest))-meanVal) ...
  /stdVal*targetStd+0.5;
autoEncOpt.MeanVal = meanVal;
autoEncOpt.StdValue = stdVal;
autoEncOpt.TargetSTDValue = targetStd;  

ニューラル ネットワーク モデルの定義と学習

AI ベースのシステムを設計する 2 番目のステップは、ニューラル ネットワーク モデルを定義して学習させることです。

ニューラル ネットワークの定義

この例では、[1] で推奨されている自己符号化器ニューラル ネットワークの修正版を使用します。

inputSize = [autoEncOpt.MaxDelay nTx 2];  % Third dimension is real and imaginary parts
nLinear = prod(inputSize);
nEncoded = 64;

autoencoderNet = dlnetwork([ ...
    % Encoder
    imageInputLayer(inputSize,"Name","Htrunc", ...
      "Normalization","none","Name","Enc_Input")

    convolution2dLayer([3 3],2,"Padding","same","Name","Enc_Conv")
    batchNormalizationLayer("Epsilon",0.001,"Name","Enc_BN")
    leakyReluLayer(0.3,"Name","Enc_leakyRelu")

    flattenLayer("Name","Enc_flatten")

    fullyConnectedLayer(nEncoded,"Name","Enc_FC")

    sigmoidLayer("Name","Enc_Sigmoid")

    % Decoder
    fullyConnectedLayer(nLinear,"Name","Dec_FC")

    functionLayer(@(x)dlarray(reshape(x,maxDelay,nTx,2,[]),'SSCB'), ...
      "Formattable",true,"Acceleratable",true,"Name","Dec_Reshape")
    ]);

autoencoderNet = ...
  helperCSINetAddResidualLayers(autoencoderNet, "Dec_Reshape");

autoencoderNet = addLayers(autoencoderNet, ...
    [convolution2dLayer([3 3],2,"Padding","same","Name","Dec_Conv") ...
    sigmoidLayer("Name","Dec_Sigmoid")]);
autoencoderNet = ...
  connectLayers(autoencoderNet,"leakyRelu_2_3","Dec_Conv");

figure
plot(autoencoderNet)
title('CSI Compression Autoencoder')

Figure contains an axes object. The axes object with title CSI Compression Autoencoder contains an object of type graphplot.

ニューラル ネットワークの学習

自己符号化器ニューラル ネットワークの学習オプションを設定し、trainnet (Deep Learning Toolbox)関数を使用してネットワークに学習させます。学習は、Intel® Xeon® W-2133 CPU @ 3.60GHz、compute capacity 7.0 の NVIDIA® TITAN V GPU、12 GB のメモリを使用して、3 分未満で完了します。事前学習済みのネットワークを読み込むには、trainNowfalse に設定します。保存済みのネットワークは次の設定で動作することに注意してください。これらの設定を変更する場合は、trainNowtrue に設定します。

txAntennaSize = [2 2 2 1 1]; % rows, columns, polarizations, panels
rxAntennaSize = [2 1 1 1 1]; % rows, columns, polarizations, panels
rmsDelaySpread = 300e-9;     % s
maxDoppler = 5;              % Hz
nSizeGrid = 52;              % Number resource blocks (RB)
                             % 12 subcarriers per RB
subcarrierSpacing = 15; 
trainNow = false;

miniBatchSize = 1000;
options = trainingOptions("adam", ...
    InitialLearnRate=0.01, ...
    LearnRateSchedule="piecewise", ...
    LearnRateDropPeriod=156, ...
    LearnRateDropFactor=0.5916, ...
    Epsilon=1e-7, ...
    MaxEpochs=1000, ...
    MiniBatchSize=miniBatchSize, ...
    Shuffle="every-epoch", ...
    ValidationData={HVReal,HVReal}, ...
    ValidationFrequency=20, ...
    Metrics="rmse", ...
    Verbose=true, ...
    ValidationPatience=20, ...
    OutputNetwork="best-validation-loss", ...
    ExecutionEnvironment="auto", ...
    Plots='training-progress') 
options = 
  TrainingOptionsADAM with properties:

             GradientDecayFactor: 0.9000
                       MaxEpochs: 1000
                InitialLearnRate: 0.0100
               LearnRateSchedule: 'piecewise'
             LearnRateDropFactor: 0.5916
             LearnRateDropPeriod: 156
                   MiniBatchSize: 1000
                         Shuffle: 'every-epoch'
         CheckpointFrequencyUnit: 'epoch'
        PreprocessingEnvironment: 'serial'
                         Verbose: 1
                VerboseFrequency: 50
                  ValidationData: {[28×8×2×300 single]  [28×8×2×300 single]}
             ValidationFrequency: 20
              ValidationPatience: 20
                         Metrics: 'rmse'
             ObjectiveMetricName: 'loss'
            ExecutionEnvironment: 'auto'
                           Plots: 'training-progress'
                       OutputFcn: []
                  SequenceLength: 'longest'
            SequencePaddingValue: 0
        SequencePaddingDirection: 'right'
                InputDataFormats: "auto"
               TargetDataFormats: "auto"
         ResetInputNormalization: 1
    BatchNormalizationStatistics: 'auto'
                   OutputNetwork: 'best-validation-loss'
                    Acceleration: "auto"
                  CheckpointPath: ''
             CheckpointFrequency: 1
        CategoricalInputEncoding: 'integer'
       CategoricalTargetEncoding: 'auto'
                L2Regularization: 1.0000e-04
         GradientThresholdMethod: 'l2norm'
               GradientThreshold: Inf
      SquaredGradientDecayFactor: 0.9990
                         Epsilon: 1.0000e-07

  Show all accessible properties of TrainingOptionsADAM

lossFunc = @(x,t) nmseLossdB(x,t);

ネットワークの入力と出力との間の dB 単位の正規化平均二乗誤差 (NMSE) を学習損失関数として使用し、自己符号化器に最適な重みのセットを見つけます。

if trainNow
  [net,trainInfo] = ...
    trainnet(HTReal,HTReal,autoencoderNet,lossFunc,options); %#ok<UNRCH> 
  savedOptions = options;
  savedOptions.ValidationData = [];
  save("dCSITrainedNetwork_" ...
    + string(datetime("now","Format","dd_MM_HH_mm")), ...
    'net','trainInfo','autoEncOpt','savedOptions')
else
  autoEncOptCached = autoEncOpt;
  load("dCSITrainedNetwork",'net','trainInfo','autoEncOpt','savedOptions')
  if autoEncOpt.NumSubcarriers ~= autoEncOptCached.NumSubcarriers ...
      || autoEncOpt.NumSymbols ~= autoEncOptCached.NumSymbols ...
      || autoEncOpt.NumTxAntennas ~= autoEncOptCached.NumTxAntennas ...
      || autoEncOpt.NumRxAntennas ~= autoEncOptCached.NumRxAntennas ...
      || autoEncOpt.MaxDelay ~= autoEncOptCached.MaxDelay
    error("CSIExample:Missmatch", ...
      "Saved network does not match settings. Set trainNow to true.")
  end
end

学習済みネットワークのテスト

predict (Deep Learning Toolbox)関数を使用してテスト データを処理します。

HTestRealHat = predict(net,HTestReal);

自己符号化器ネットワークの入力と出力との間の相関と NMSE を計算します。相関は次のように定義されます。

ρ=E{1Nn=1N|hˆnHhn|hˆn2hn2}

ここで、hn は自己符号化器の入力におけるチャネル推定で、hˆn は自己符号化器の出力におけるチャネル推定です。NMSE は次のように定義されます。

NMSE=E{H-Hˆ22H22}normalized mean square error is equal to the square of the second norm of the difference between autoencoder input and output, divided y the square of the seconf norm of the autoencoder input.

ここで、H は自己符号化器の入力におけるチャネル推定で、Hˆ は自己符号化器の出力におけるチャネル推定です。

rho = zeros(numTest,1);
nmse = zeros(numTest,1);
for n=1:numTest
    in = HTestReal(:,:,1,n) + 1i*(HTestReal(:,:,2,n));
    out = HTestRealHat(:,:,1,n) + 1i*(HTestRealHat(:,:,2,n));

    % Calculate correlation
    n1 = sqrt(sum(conj(in).*in,'all'));
    n2 = sqrt(sum(conj(out).*out,'all'));
    aa = abs(sum(conj(in).*out,'all'));
    rho(n) = aa / (n1*n2);

    % Calculate NMSE
    mse = mean(abs(in-out).^2,'all');
    nmse(n) = 10*log10(mse / mean(abs(in).^2,'all'));
end

figure
tiledlayout(2,1)
nexttile
histogram(rho,"Normalization","probability")
grid on
title(sprintf("Autoencoder Correlation (Mean \\rho = %1.5f)", ...
  mean(rho)))
xlabel("\rho"); ylabel("PDF")
nexttile
histogram(nmse,"Normalization","probability")
grid on
title(sprintf("Autoencoder NMSE (Mean NMSE = %1.2f dB)",mean(nmse)))
xlabel("NMSE (dB)"); ylabel("PDF")

Figure contains 2 axes objects. Axes object 1 with title Autoencoder Correlation (Mean blank rho blank = blank 0 . 99999 ), xlabel \rho, ylabel PDF contains an object of type histogram. Axes object 2 with title Autoencoder NMSE (Mean NMSE = -46.41 dB), xlabel NMSE (dB), ylabel PDF contains an object of type histogram.

エンドツーエンドの CSI フィードバック システム

次の図は、CSI フィードバックのチャネル推定に関するエンドツーエンドの処理を示しています。UE によって、CSI-RS 信号を使用して 1 つのスロットのチャネル応答 Hest が推定されます。自己符号化器の符号化器部分を使用して、前処理されたチャネル推定 Htr が符号化され、1 行 Nenc 列の圧縮された配列が生成されます。自己符号化器の復号化器部分によって、圧縮された配列が圧縮解除され、Htrˆ が取得されます。Htrˆ を後処理することで、Hestˆ が生成されます。

End-to-end CSI compression

符号化された配列を取得するには、自己符号化器を符号化器ネットワークと復号化器ネットワークの 2 つの部分に分割します。

[encNet,decNet] = helperCSINetSplitEncoderDecoder(net,"Enc_Sigmoid");
plotNetwork(net,encNet,decNet)

Figure contains 3 axes objects. Axes object 1 with title Autoencoder contains an object of type graphplot. Axes object 2 with title Encoder contains an object of type graphplot. Axes object 3 with title Decoder contains an object of type graphplot.

100 スロットのチャネル推定を生成します。各フレームには 1 つのスロットが含まれ、各フレームの後でチャネルがリセットされます。

numFrames = 100;
[autoEncOpt,channel] = addSimOptions(autoEncOpt,channel,carrier);
Hest = helperCSIGenerateData(numFrames,channel,carrier,autoEncOpt);

Normalizationtrue に設定し、チャネル推定を符号化および復号化します。

autoEncOpt.Normalization = true;
codeword = helperCSINetEncode(encNet, Hest, autoEncOpt);
Hhat = helperCSINetDecode(decNet, codeword, autoEncOpt);

エンドツーエンドの CSI フィードバック システムの相関と NMSE を計算します。

H = squeeze(mean(Hest,2));
rhoE2E = zeros(nRx,numFrames);
nmseE2E = zeros(nRx,numFrames);
for rx=1:nRx
    for n=1:numFrames
        out = Hhat(:,rx,:,n);
        in = H(:,rx,:,n);
        rhoE2E(rx,n) = helperCSINetCorrelation(in,out);
        nmseE2E(rx,n) = helperNMSE(in,out);
    end
end
figure
tiledlayout(2,1)
nexttile
histogram(rhoE2E,"Normalization","probability")
grid on
title(sprintf("End-to-End Correlation (Mean \\rho = %1.5f)", ...
  mean(rhoE2E,'all')))
xlabel("\rho"); ylabel("PDF")
nexttile
histogram(nmseE2E,"Normalization","probability")
grid on
title(sprintf("End-to-End NMSE (Mean NMSE = %1.2f dB)", ...
  mean(nmseE2E,'all')))
xlabel("NMSE (dB)"); ylabel("PDF")

Figure contains 2 axes objects. Axes object 1 with title End-to-End blank Correlation blank (Mean blank rho blank = blank 0 . 99053 ), xlabel \rho, ylabel PDF contains an object of type histogram. Axes object 2 with title End-to-End NMSE (Mean NMSE = -17.56 dB), xlabel NMSE (dB), ylabel PDF contains an object of type histogram.

量子化されたコードワードの効果

実際のシステムでは、符号化されたコードワードを少ないビット数で量子化しなければなりません。[2, 10] ビットの範囲にわたって量子化の効果をシミュレートします。結果は、6 ビットで単精度のパフォーマンスを近似するのに十分であることを示しています。

CSI compression with autoencoder and quantization

maxVal = 1;
minVal = -1;
idxBits = 1;
nBitsVec = 2:10;
rhoQ = zeros(nRx,numFrames,length(nBitsVec));
nmseQ = zeros(nRx,numFrames,length(nBitsVec));
for numBits = nBitsVec
    disp("Running for " + numBits + " bit quantization")

    % Quantize between 0:2^n-1 to get bits
    qCodeword = uencode(double(codeword*2-1), numBits);

    % Get back the floating point, quantized numbers
    codewordRx = (single(udecode(qCodeword,numBits))+1)/2;
    Hhat = helperCSINetDecode(decNet, codewordRx, autoEncOpt);
    H = squeeze(mean(Hest,2));
    for rx=1:nRx
        for n=1:numFrames
            out = Hhat(:,rx,:,n);
            in = H(:,rx,:,n);
            rhoQ(rx,n,idxBits) = helperCSINetCorrelation(in,out);
            nmseQ(rx,n,idxBits) = helperNMSE(in,out);
        end
    end
    idxBits = idxBits + 1;
end
Running for 2 bit quantization
Running for 3 bit quantization
Running for 4 bit quantization
Running for 5 bit quantization
Running for 6 bit quantization
Running for 7 bit quantization
Running for 8 bit quantization
Running for 9 bit quantization
Running for 10 bit quantization
figure
tiledlayout(2,1)
nexttile
plot(nBitsVec,squeeze(mean(rhoQ,[1 2])),'*-')
title("Correlation (Codeword-" + size(codeword,3) + ")")
xlabel("Number of Quantization Bits"); ylabel("\rho")
grid on
nexttile
plot(nBitsVec,squeeze(mean(nmseQ,[1 2])),'*-')
title("NMSE (Codeword-" + size(codeword,3) + ")")
xlabel("Number of Quantization Bits"); ylabel("NMSE (dB)")
grid on

Figure contains 2 axes objects. Axes object 1 with title Correlation (Codeword-64), xlabel Number of Quantization Bits, ylabel \rho contains an object of type line. Axes object 2 with title NMSE (Codeword-64), xlabel Number of Quantization Bits, ylabel NMSE (dB) contains an object of type line.

その他の調査

自己符号化器は、[624 8] の単精度複素チャネル推定配列を、平均相関係数が 0.99 で NMSE が -19.55 dB である [64 1] 単精度配列に圧縮することができます。6 ビットの量子化を使用すると、必要な CSI フィードバック データは 384 ビットのみとなり、圧縮率は約 800:1 になります。

display("Compression ratio is " + (624*8*32*2)/(64*6) + ":" + 1)
    "Compression ratio is 832:1"

truncationFactor がシステムのパフォーマンスに与える影響を調査します。5G システムのパラメーター、チャネルのパラメーター、符号化されたシンボルの数を変更し、定義されたチャネルに最適な値を見つけます。

チャネル状態情報フィードバックを使用した NR PDSCH のスループットの例では、チャネル状態情報 (CSI) フィードバックを使用して物理ダウンリンク共有チャネル (PDSCH) のパラメーターを調整し、スループットを測定する方法を示しています。CSI フィードバック アルゴリズムを CSI 圧縮自己符号化器に置き換えて、パフォーマンスを比較します。

補助関数

補助関数を調べて、システムの詳細な実装を確認します。

学習データの生成

helperCSINetGenerateData

helperCSIGenerateData

helperCSIChannelEstimate

ネットワークの定義と操作

helperCSINetDLNetwork

helperCSINetAddResidualLayers

helperCSINetSplitEncoderDecoder

CSI の処理

helperCSIPreprocessChannelEstimate

helperCSINetPostprocessChannelEstimate

helperCSINetEncode

helperCSINetDecode

パフォーマンスの測定

helperCSINetCorrelation

helperNMSE

付録: 実験マネージャーを使ったハイパーパラメーターの最適化

最適なパラメーターを見つけるには、実験マネージャー アプリを使用します。CSITrainingProject.mlproj は事前構成されたプロジェクトです。プロジェクトを抽出します。

projectName = "CSITrainingProject";
if ~exist(projectName,"dir")
  projRoot = helperCSINetExtractProject(projectName);
else
  projRoot = fullfile(exRoot(),projectName);
end

プロジェクトを開くには、実験マネージャー アプリを起動し、次のファイルを開きます。

disp(fullfile(".","CSITrainingProject","CSITrainingProject.prj"))
.\CSITrainingProject\CSITrainingProject.prj

ハイパーパラメーターの最適化実験では、次の図のように、ハイパーパラメーターの検索範囲が指定されたベイズ最適化が使用されます。プロジェクトを開いたら、実験設定関数 CSIAutoEncNN_setup を使用できます。カスタム メトリクス関数は E2E_NMSE です。

ExperimentSetup.png

ExperimentResults.png

最適なパラメーターは、初期学習率が 0.01、学習率低下期間が 156 回の反復、学習率低下係数が 0.5916 です。最適なハイパーパラメーターを見つけたら、同じパラメーターを使ってネットワークを複数回学習させ、学習結果が最も良いネットワークを見つけます。

ExperimentSetp2.png

9 回目の試行で最適な E2E_NMSE が得られました。この例では、この学習済みのネットワークを保存済みのネットワークとして使用します。

ExperimentResults2.png

バッチ モードの構成

実行の [モード]Batch Sequential または Batch Simultaneous に設定されている場合、学習データは、[Prepare Data in Bulk] セクションの dataDir 変数によって定義された場所にあるワーカーにアクセスできなければなりません。dataDir を、ワーカーがアクセスできるネットワークの場所に設定してください。詳細については、Offload Experiments as Batch Jobs to a Cluster (Deep Learning Toolbox)を参照してください。

ローカル関数

function plotNetwork(net,encNet,decNet)
%plotNetwork Plot autoencoder network
%   plotNetwork(NET,ENC,DEC) plots the full autoencoder network together
%   with encoder and decoder networks.
fig = figure;
t1 = tiledlayout(1,2,'TileSpacing','Compact');
t2 = tiledlayout(t1,1,1,'TileSpacing','Tight');
t3 = tiledlayout(t1,2,1,'TileSpacing','Tight');
t3.Layout.Tile = 2;
nexttile(t2)
plot(net)
title("Autoencoder")
nexttile(t3)
plot(encNet)
title("Encoder")
nexttile(t3)
plot(decNet)
title("Decoder")
pos = fig.Position;
pos(3) = pos(3) + 200;
pos(4) = pos(4) + 300;
pos(2) = pos(2) - 300;
fig.Position = pos;
end

function rootDir = exRoot()
%exRoot Example root directory
rootDir = fileparts(which("helperCSINetDLNetwork"));
end

function loss = nmseLossdB(x,xHat)
%nmseLossdB NMSE loss in dB
in = complex(x(:,:,1,:),x(:,:,2,:));
out = complex(xHat(:,:,1,:),xHat(:,:,2,:));
nmsePerObservation = helperNMSE(in,out);
loss = mean(nmsePerObservation);
end

function [opt,channel] = addSimOptions(opt,channel,carrier)
opt.SaveData = false;
opt.Preprocess = false;
if isa(channel,"nrCDLChannel")
  % Make sure that this is high enough for nrPerfectChannelEstimate to return
  % the full number of symbols worth of channel estimates
  opt.ChannelSampleDensity = 64*4;
end

waveInfo = nrOFDMInfo(carrier);
channel.SampleRate = waveInfo.SampleRate;

numSubCarriers = carrier.NSizeGrid*12; % 12 subcarriers per RB
Tdelay = 1/(numSubCarriers*carrier.SubcarrierSpacing*1e3);
opt.MaxDelay = round((channel.DelaySpread/Tdelay)*opt.TruncationFactor/2)*2;

opt.NumSlotsPerFrame = 1;
opt.Preprocess = false;
opt.ResetChannelPerFrame = true;
opt.Normalization = false;
opt.Verbose = false;
end

参考文献

[1] Wen, Chao-Kai, Wan-Ting Shih, and Shi Jin. “Deep Learning for Massive MIMO CSI Feedback.” IEEE Wireless Communications Letters 7, no. 5 (October 2018): 748–51. https://doi.org/10.1109/LWC.2018.2818160.

[2] Zimaglia, Elisa, Daniel G. Riviello, Roberto Garello, and Roberto Fantini. “A Novel Deep Learning Approach to CSI Feedback Reporting for NR 5G Cellular Systems.” In 2020 IEEE Microwave Theory and Techniques in Wireless Communications (MTTW), 47–52. Riga, Latvia: IEEE, 2020. https://doi.org/10.1109/MTTW51045.2020.9245055.

参考

トピック