CSI フィードバック圧縮のための自己符号化器の学習
この例では、クラスター遅延線 (CDL) チャネルによってダウンリンク チャネル状態情報 (CSI) を圧縮するための自己符号化器ニューラル ネットワークの学習を行う方法を示します。
この例では、次のことを行います。
CSI フィードバック自己符号化のためのニューラル ネットワーク モデルの定義と学習。
前処理、符号化、復号化、および後処理を含む、完全な CSI 圧縮システムの学習済みネットワークのテスト。
システムの性能に対する量子化済みコードワードの影響のテスト。
CSI フィードバックのための AI ワークフロー
AI ベースの CSI フィードバック ワークフローの手順には、データの生成、データの準備、モデルの学習、およびモデルのテストが含まれます。各ステップを個別に実行することも、ステップを順番に実行することもできます。この例では "モデルの学習" に焦点を当てます。
CSI フィードバック プロセスと AI ワークフローの説明については、AI-Based CSI Feedbackを参照してください。概要として、ワークフローの手順は以下のとおりです。
1. データの生成 - Generate MIMO OFDM Channel Realizations for AI-Based Systemsの例に従って、チャネル推定データを生成します。
2. データの準備 - Preprocess Data for AI-Based CSI Feedback Compressionの例に従って、データの準備を行います。
3. モデルの学習 - モデルの学習では、この例のニューラル ネットワーク モデルの定義と学習セクションで説明されているように、前処理されたチャネル推定データをニューラル ネットワークに入力し、CSI データを再構成します。
4. モデルのテスト - モデルのテストは、Test AI-based CSI Compression Techniques for Enhanced PDSCH Throughputの例で重点的に説明されています。
自己符号化器モデルの学習、圧縮、テストを行うその他の例の一覧については、その他の調査セクションを参照してください。
ニューラル ネットワーク モデルの定義と学習
必要なデータがワークスペースに存在しない場合は、データを生成して準備します。データを前処理した後、prepareData 関数の出力 (inputData、systemParams、dataOptions、channel、および carrier) を調べることで、システムの構成を確認できます。
if ~exist("inputData","var") || ~exist("systemParams","var") ... || ~exist("dataOptions","var") || ~exist("channel","var") ... || ~exist("carrier","var") numSamples =1000; [inputData,systemParams,dataOptions,channel,carrier] = ... prepareData(numSamples); end
Starting channel realization generation 6 worker(s) running 00:00:13 - 100% Completed Starting CSI data preprocessing 6 worker(s) running 00:00:02 - 100% Completed
ニューラル ネットワーク モデルの変数の定義
ニューラル ネットワーク モデルを定義する変数を初期化します。inputData 変数には、××2 の配列から成る 個のサンプルが含まれています。
[maxDelay,nTx,Niq,Nsamples] = size(inputData)
maxDelay = 28
nTx = 8
Niq = 2
Nsamples = 2000
systemParams.MaxDelay = maxDelay;
データを学習セット、検証セット、テスト セットに分割します。
N = size(inputData, 4); numTrain = floor(N*10/15)
numTrain = 1333
numVal = floor(N*3/15)
numVal = 400
numTest = floor(N*2/15)
numTest = 266
inputDataT = inputData(:,:,:,1:numTrain); inputDataV = inputData(:,:,:,numTrain+(1:numVal)); inputDataTest = inputData(:,:,:,numTrain+numVal+(1:numTest));
この例では、[1] で推奨されている自己符号化器ニューラル ネットワークの修正版を使用します。
inputSize = [maxDelay nTx 2]; % 3rd dim for real and imaginary nLinear = prod(inputSize); nEncoded = 64; autoencoderNet = dlnetwork([ ... % Encoder imageInputLayer(inputSize, ... "Normalization","none","Name","Enc_Input") convolution2dLayer([3 3],2, ... "Padding","same","Name","Enc_Conv") batchNormalizationLayer("Epsilon",0.001,"Name","Enc_BN") leakyReluLayer(0.3,"Name","Enc_leakyRelu") flattenLayer("Name","Enc_flatten") fullyConnectedLayer(nEncoded,"Name","Enc_FC") sigmoidLayer("Name","Enc_Sigmoid") % Decoder fullyConnectedLayer(nLinear,"Name","Dec_FC") functionLayer(@(x)dlarray(reshape(x,maxDelay,nTx,2,[]),'SSCB'), ... "Formattable",true,"Acceleratable",true,"Name","Dec_Reshape") ]); autoencoderNet = ... helperCSINetAddResidualLayers(autoencoderNet, "Dec_Reshape"); autoencoderNet = addLayers(autoencoderNet, ... [convolution2dLayer([3 3],2,"Padding","same","Name","Dec_Conv") ... sigmoidLayer("Name","Dec_Sigmoid")]); autoencoderNet = ... connectLayers(autoencoderNet,"leakyRelu_2_3","Dec_Conv"); figure plot(autoencoderNet) title('CSI Compression Autoencoder')

ニューラル ネットワークの学習
自己符号化器ニューラル ネットワークの学習オプションを設定し、trainnet (Deep Learning Toolbox)関数を使用してネットワークに学習させます。学習は、Intel® Xeon® W-2133 CPU @ 3.60GHz と NVIDIA GeForce RTX 3080 GPU を使用して 13 分未満で完了します。事前学習済みのネットワークを読み込むには、trainNow を false に設定します。保存済みのネットワークは次の設定で動作することに注意してください。これらの設定を変更する場合は、trainNow を true に設定します。
txAntennaSize = [2 2 2 1 1]; % rows, columns, polarizations, panels
rxAntennaSize = [2 1 1 1 1]; % rows, columns, polarizations, panels
rmsDelaySpread = 300e-9; % s
maxDoppler = 5; % Hz
nSizeGrid = 52; % Number resource blocks (RB)
% 12 subcarriers per RB
subcarrierSpacing = 15;
trainNow =false; miniBatchSize = 1000; trainOptions = trainingOptions("adam", ... InitialLearnRate=0.01, ... LearnRateSchedule="piecewise", ... LearnRateDropPeriod=138, ... LearnRateDropFactor=0.7456, ... Epsilon=1e-7, ... MaxEpochs=1000, ... MiniBatchSize=miniBatchSize, ... Shuffle="every-epoch", ... ValidationData={inputDataV,inputDataV}, ... ValidationFrequency=20, ... ValidationPatience=20, ... Metrics="rmse", ... Verbose=true, ... OutputNetwork="best-validation-loss", ... ExecutionEnvironment="auto", ... Plots='none')
trainOptions =
TrainingOptionsADAM with properties:
GradientDecayFactor: 0.9000
MaxEpochs: 1000
InitialLearnRate: 0.0100
LearnRateSchedule: 'piecewise'
LearnRateDropFactor: 0.7456
LearnRateDropPeriod: 138
MiniBatchSize: 1000
Shuffle: 'every-epoch'
CheckpointFrequencyUnit: 'epoch'
PreprocessingEnvironment: 'serial'
Verbose: 1
VerboseFrequency: 50
ValidationData: {[28×8×2×400 single] [28×8×2×400 single]}
ValidationFrequency: 20
ValidationPatience: 20
Metrics: 'rmse'
ObjectiveMetricName: 'loss'
ExecutionEnvironment: 'auto'
Plots: 'none'
OutputFcn: []
SequenceLength: 'longest'
SequencePaddingValue: 0
SequencePaddingDirection: 'right'
InputDataFormats: "auto"
TargetDataFormats: "auto"
ResetInputNormalization: 1
ResetInverseNormalization: 1
NormalizeTargets: 0
BatchNormalizationStatistics: 'auto'
OutputNetwork: 'best-validation-loss'
Acceleration: "auto"
CheckpointPath: ''
CheckpointFrequency: 1
CategoricalInputEncoding: 'integer'
CategoricalTargetEncoding: 'auto'
L2Regularization: 1.0000e-04
GradientThresholdMethod: 'l2norm'
GradientThreshold: Inf
SquaredGradientDecayFactor: 0.9990
Epsilon: 1.0000e-07
lossFunc = @(x,t) nmseLossdB(x,t);
ネットワークの入力と出力との間の dB 単位の正規化平均二乗誤差 (NMSE) を学習損失関数として使用し、自己符号化器に最適な重みのセットを見つけます。
if trainNow [net,trainInfo] = ... trainnet(inputDataT,inputDataT,autoencoderNet,lossFunc,trainOptions); %#ok<UNRCH> save("csiTrainedNetwork_" ... + string(datetime("now","Format","dd_MM_HH_mm")), ... 'net','trainInfo','systemParams','dataOptions','trainOptions') else systemParamsCached = systemParams; load("csiTrainedNetwork202507",'net','trainInfo','systemParams','trainOptions') if ~checkSystemCompatibility(systemParams,systemParamsCached) error("CSIExample:Missmatch", ... "Saved network does not match settings. Set trainNow to true.") end end
学習済みネットワークのテスト
predict (Deep Learning Toolbox)関数を使用してテスト データを処理します。
Hhat = predict(net,inputDataTest);
自己符号化器ネットワークの入力と出力との間のコサイン類似度と NMSE を計算します。コサイン類似度は次のように定義されます。
ここで、 は自己符号化器の入力におけるチャネル推定で、 は自己符号化器の出力におけるチャネル推定です。コサイン類似度の詳細については、「Cosine Similarity As a Channel Estimate Quality Metric」の例を参照してください。NMSE は次のように定義されます。
ここで、 は自己符号化器の入力におけるチャネル推定で、 は自己符号化器の出力におけるチャネル推定です。
cossim = zeros(numTest,1); nmse = zeros(numTest,1); for n=1:numTest in = inputDataTest(:,:,1,n) + 1i*(inputDataTest(:,:,2,n)); out = Hhat(:,:,1,n) + 1i*(Hhat(:,:,2,n)); % Calculate correlation cossim(n) = helperComplexCosineSimilarity(in,out); % Calculate NMSE mse = mean(abs(in-out).^2,'all'); nmse(n) = 10*log10(mse / mean(abs(in).^2,'all')); end figure tiledlayout(3,1) nexttile histogram(abs(cossim),"Normalization","probability") grid on title(sprintf("Cosine Similarity Magnitude (Mean = %1.2f)", ... mean(abs(cossim),'all'))) xlabel("Cosine Similarity Magnitude"); ylabel("PDF") nexttile histogram(angle(cossim),"Normalization","probability") grid on title(sprintf("Cosine Similarity Angle (Mean = %1.2f)", ... mean(angle(cossim),'all'))) xlabel("Cosine Similarity Angle"); ylabel("PDF") nexttile histogram(nmse,"Normalization","probability") grid on title(sprintf("NMSE (Mean NMSE = %1.2f dB)", ... mean(nmse,'all'))) xlabel("NMSE (dB)"); ylabel("PDF")

完全な CSI フィードバック システム
次の図は、CSI フィードバックのチャネル推定に関する全体の処理を示しています。UE によって、CSI-RS 信号を使用して 1 つのスロットのチャネル応答 が推定されます。自己符号化器の符号化器部分を使用して、前処理されたチャネル推定 が符号化され、1 行 列の圧縮された配列が生成されます。自己符号化器の復号化器部分によって、圧縮された配列が圧縮解除され、 が取得されます。 を後処理することで、 が生成されます。

符号化された配列を取得するには、自己符号化器を符号化器ネットワークと復号化器ネットワークの 2 つの部分に分割します。
[encNet,decNet] = helperCSINetSplitEncoderDecoder(net,"Enc_Sigmoid");
plotNetwork(net,encNet,decNet)
チャネル推定を生成します。
numFrames = 100; nRx = prod(systemParams.RxAntennaSize); Hest = helper3GPPChannelRealizations(... numFrames, ... channel, ... carrier, ... UseParallel = false, ... SaveData = false, ... Verbose = false, ... ResetChannelPerFrame = true, ... NumSlotsPerFrame = 1);
チャネル推定を符号化および復号化します。
codeword = helperCSINetEncode(encNet,Hest,systemParams); Hhat = helperCSINetDecode(decNet,codeword,systemParams);
完全な CSI フィードバック システムについて、コサイン類似度と NMSE を計算します。
H = squeeze(mean(Hest,2)); nmseE2E = zeros(nRx,numFrames); cossimE2E = zeros(nRx,numFrames); for rx=1:nRx for n=1:numFrames out = Hhat(:,rx,:,n); in = H(:,rx,:,n); cossimE2E(rx,n) = mean(helperComplexCosineSimilarity(in,out)); nmseE2E(rx,n) = helperNMSE(in,out); end end figure tiledlayout(3,1) nexttile histogram(abs(cossimE2E),"Normalization","probability") grid on title(sprintf("Complete Cosine Similarity Magnitude (Mean = %1.2f)", ... mean(abs(cossimE2E),'all'))) xlabel("Cosine Similarity Magnitude"); ylabel("PDF") nexttile histogram(angle(cossimE2E),"Normalization","probability") grid on title(sprintf("Complete Cosine Similarity Angle (Mean = %1.2f)", ... mean(angle(cossimE2E),'all'))) xlabel("Cosine Similarity Angle"); ylabel("PDF") nexttile histogram(nmseE2E,"Normalization","probability") grid on title(sprintf("Complete NMSE (Mean NMSE = %1.2f dB)", ... mean(nmseE2E,'all'))) xlabel("NMSE (dB)"); ylabel("PDF")

量子化されたコードワードの効果
実際のシステムでは、符号化されたコードワードを少ないビット数で量子化しなければなりません。[2,10] ビットの範囲で量子化の影響をシミュレーションします。結果は、6 ビットで単精度のパフォーマンスを近似するのに十分であることを示しています。

maxVal = 1; minVal = -1; idxBits = 1; nBitsVec = 2:10; rhoQ = zeros(nRx,numFrames,length(nBitsVec)); nmseQ = zeros(nRx,numFrames,length(nBitsVec)); for numBits = nBitsVec disp("Running for " + numBits + " bit quantization") % Quantize between 0:2^n-1 to get bits qCodeword = uencode(double(codeword*2-1), numBits); % Get back the floating point, quantized numbers codewordRx = (single(udecode(qCodeword,numBits))+1)/2; Hhat = helperCSINetDecode(decNet,codewordRx,systemParams); H = squeeze(mean(Hest,2)); for rx=1:nRx for n=1:numFrames out = Hhat(:,rx,:,n); in = H(:,rx,:,n); rhoQ(rx,n,idxBits) = helperCSINetCorrelation(in,out); nmseQ(rx,n,idxBits) = helperNMSE(in,out); end end idxBits = idxBits + 1; end
Running for 2 bit quantization Running for 3 bit quantization Running for 4 bit quantization Running for 5 bit quantization Running for 6 bit quantization Running for 7 bit quantization Running for 8 bit quantization Running for 9 bit quantization Running for 10 bit quantization
figure tiledlayout(2,1) nexttile plot(nBitsVec,squeeze(mean(rhoQ,[1 2])),'*-') title("Correlation (Codeword-" + size(codeword,3) + ")") xlabel("Number of Quantization Bits"); ylabel("\rho") grid on nexttile plot(nBitsVec,squeeze(mean(nmseQ,[1 2])),'*-') title("NMSE (Codeword-" + size(codeword,3) + ")") xlabel("Number of Quantization Bits"); ylabel("NMSE (dB)") grid on

その他の調査
自己符号化器は、[624 8] の単精度複素チャネル推定配列を、平均相関係数が 0.99 で NMSE が -19.55 dB である [64 1] 単精度配列に圧縮します。6 ビットの量子化を使用すると、必要な CSI フィードバック データは 384 ビットのみとなり、圧縮率は約 800:1 になります。
display("Compression ratio is " + (624*8*32*2)/(64*6) + ":" + 1)
"Compression ratio is 832:1"
truncationFactor がシステムのパフォーマンスに与える影響を調査します。5G システムのパラメーター、チャネルのパラメーター、符号化されたシンボルの数を変更し、定義されたチャネルに最適な値を見つけます。
チャネル状態情報フィードバックを使用した NR PDSCH のスループットの例では、チャネル状態情報 (CSI) フィードバックを使用して物理ダウンリンク共有チャネル (PDSCH) のパラメーターを調整し、スループットを測定する方法を示しています。CSI フィードバック アルゴリズムを CSI 圧縮自己符号化器に置き換えて、パフォーマンスを比較します。
この例では、CSI の圧縮および再構成のための自己符号化器の設計、学習、および評価を行う方法を示します。タスクに固有のその他のプロセスについては、以下の例を参照してください。
CSI Feedback with Transformer Autoencoder — CSI の圧縮と再構成のためのトランスフォーマー自己符号化器を設計、学習、および評価します。
Optimize CSI Feedback Autoencoder Training Using MATLAB Parallel Server and Experiment Manager — MATLAB® Parallel Server™ と実験マネージャー アプリを使用して、チャネル状態情報 (CSI) の圧縮をシミュレーションする自己符号化器モデルの最適な学習ハイパーパラメーターの決定を高速化します。
CSI Feedback with Autoencoders Implemented on an FPGA (Deep Learning HDL Toolbox) — Deep Learning HDL Toolbox™ を使用して、実装済みの CSI 自己符号化器を FPGA に展開します。
MATLAB 上で動作する PyTorch ベースおよび Keras ベースのニューラル ネットワークの学習とテストの方法についても学ぶことができます。
補助関数
補助関数を調べて、システムの詳細な実装を確認します。
学習データの生成
helper3GPPChannelRealizations
ネットワークの定義と操作
helperCSINetDLNetwork
helperCSINetAddResidualLayers
helperCSINetSplitEncoderDecoder
CSI の処理
helperPreprocess3GPPChannelData
helperPostprocess3GPPChannelData
helperCSINetEncode
helperCSINetDecode
パフォーマンスの測定
helperComplexCosineSimilarity
helperNMSE
付録: 実験マネージャーを使ったハイパーパラメーターの最適化
最適なパラメーターを見つけるには、実験マネージャー アプリを使用します。CSITrainingProject.mlproj は事前構成されたプロジェクトです。プロジェクトを抽出します。
projectName = "CSITrainingProject"; if ~exist(projectName,"dir") projRoot = helperCSINetExtractProject(projectName); else projRoot = fullfile(exRoot(),projectName); end
プロジェクトを開くには、実験マネージャー アプリを起動し、次のファイルを開きます。
disp(fullfile(".","CSITrainingProject","CSITrainingProject.prj"))
.\CSITrainingProject\CSITrainingProject.prj
ハイパーパラメーター最適化時に使用する入力データと自己符号化器オプションを保存します。
dataDir = fullfile(pwd,"Data","processed"); if ~isfolder(dataDir) mkdir(dataDir) end save(fullfile(dataDir,"nr_channel_preprocessed.mat"), ... "inputData","systemParams") save('data_folder.mat', "dataDir");
ハイパーパラメーターの最適化実験では、次の図のように、ハイパーパラメーターの検索範囲が指定されたベイズ最適化が使用されます。プロジェクトを開いたら、実験設定関数 CSIAutoEncNN_setup を使用できます。カスタム メトリクス関数は E2E_NMSE です。


最適なパラメーターは、初期学習率が 0.01、学習率低下期間が 156 回の反復、学習率低下係数が 0.5916 です。最適なハイパーパラメーターを見つけたら、同じパラメーターを使ってネットワークを複数回学習させ、学習結果が最も良いネットワークを見つけます。

9 回目の試行で最適な E2E_NMSE が得られました。この例では、この学習済みのネットワークを保存済みのネットワークとして使用します。

バッチ モードの構成
実行の [モード] が Batch Sequential または Batch Simultaneous に設定されている場合、学習データは、[Prepare Data in Bulk] セクションの dataDir 変数によって定義された場所にあるワーカーにアクセスできなければなりません。dataDir を、ワーカーがアクセスできるネットワークの場所に設定してください。詳細については、Offload Experiments as Batch Jobs to a Cluster (Deep Learning Toolbox)を参照してください。
ローカル関数
function [inputData,systemParams,dataOptions,channel,carrier] = prepareData(numSamples) carrier = nrCarrierConfig; nSizeGrid = 52; % Number resource blocks (RB) systemParams.SubcarrierSpacing =15; % 15, 30, 60, 120 kHz carrier.NSizeGrid = nSizeGrid; carrier.SubcarrierSpacing = systemParams.SubcarrierSpacing; waveInfo = nrOFDMInfo(carrier); systemParams.TxAntennaSize = [2 2 2 1 1]; % rows, columns, polarization, panels systemParams.RxAntennaSize = [2 1 1 1 1]; % rows, columns, polarization, panels systemParams.MaxDoppler = 5; % Hz systemParams.RMSDelaySpread = 300e-9; % s systemParams.DelayProfile =
"CDL-C"; % CDL-A, CDL-B, CDL-C, CDL-D, CDL-D, CDL-E systemParams.NumSubcarriers = carrier.NSizeGrid*12; channel = nrCDLChannel; channel.DelayProfile = systemParams.DelayProfile; channel.DelaySpread = systemParams.RMSDelaySpread; % s channel.MaximumDopplerShift = systemParams.MaxDoppler; % Hz channel.RandomStream = "Global stream"; channel.TransmitAntennaArray.Size = systemParams.TxAntennaSize; channel.ReceiveAntennaArray.Size = systemParams.RxAntennaSize; channel.ChannelFiltering = false; channel.SampleRate = waveInfo.SampleRate; samplesPerSlot = ... sum(waveInfo.SymbolLengths(1:waveInfo.SymbolsPerSlot)); channel.NumTimeSamples = samplesPerSlot; % 1 slot worth of samples systemParams.NumSymbols = 14; useParallel =
true; saveData =
true; dataDir = fullfile(pwd,"Data"); dataFilePrefix = "nr_channel_est"; numSlotsPerFrame = 1; resetChannel = true; sdsChan = helper3GPPChannelRealizations(... numSamples, ... channel, ... carrier, ... UseParallel=useParallel, ... SaveData=saveData, ... DataDir=dataDir, ... dataFilePrefix=dataFilePrefix, ... NumSlotsPerFrame=numSlotsPerFrame, ... ResetChannelPerFrame=resetChannel); dataOptions.DataDomain =
"Frequency-Spatial (FS)"; dataOptions.TruncationFactor =
10; Tdelay = 1/(systemParams.NumSubcarriers*carrier.SubcarrierSpacing*1e3); rmsDelaySpreadSamples = channel.DelaySpread/Tdelay; [data,dataOptions] = helperPreprocess3GPPChannelData( ... sdsChan, ... TrainingObjective = "autoencoding", ... AverageOverSlots = true, ... TruncateChannel = true, ... ExpectedDelaySpreadSamples = rmsDelaySpreadSamples, ... TruncationFactor = dataOptions.TruncationFactor, ... DataComplexity = "real (2D)", ... IQDimension = 3, ... DataDomain = dataOptions.DataDomain, ... UseParallel = useParallel, ... SaveData = false); meanVal = mean(data{1},'all'); stdVal = std(data{1},[],'all'); inputData = (data{1}-meanVal) / stdVal; targetStd = 0.0212; inputData = inputData*targetStd+0.5; systemParams.Normalization = "mean-variance"; systemParams.MeanValue = meanVal; systemParams.StandardDeviationValue = stdVal; systemParams.TargetStandardDeviation = targetStd; systemParams.ExpectedDelaySpreadSamples = dataOptions.ExpectedDelaySpreadSamples; end function compatible = checkSystemCompatibility(systemParams,systemParamsCached) compatible = false; if systemParams.SubcarrierSpacing ~= systemParamsCached.SubcarrierSpacing return end if any(systemParams.TxAntennaSize ~= systemParamsCached.TxAntennaSize) return end if any(systemParams.RxAntennaSize ~= systemParamsCached.RxAntennaSize) return end if systemParams.MaxDoppler ~= systemParamsCached.MaxDoppler return end if systemParams.RMSDelaySpread ~= systemParamsCached.RMSDelaySpread return end if systemParams.DelayProfile ~= systemParamsCached.DelayProfile return end if systemParams.NumSubcarriers ~= systemParamsCached.NumSubcarriers return end if systemParams.NumSymbols ~= systemParamsCached.NumSymbols return end if abs(systemParams.MeanValue - systemParamsCached.MeanValue) > 3e-2 return end if (systemParams.StandardDeviationValue - systemParamsCached.StandardDeviationValue) > 3e-2 return end if systemParams.TargetStandardDeviation ~= systemParamsCached.TargetStandardDeviation return end if systemParams.ExpectedDelaySpreadSamples ~= systemParamsCached.ExpectedDelaySpreadSamples return end if systemParams.MaxDelay ~= systemParamsCached.MaxDelay return end compatible = true; end function plotNetwork(net,encNet,decNet) %plotNetwork Plot autoencoder network % plotNetwork(NET,ENC,DEC) plots the full autoencoder network together % with encoder and decoder networks. fig = figure; t1 = tiledlayout(1,2,'TileSpacing','Compact'); t2 = tiledlayout(t1,1,1,'TileSpacing','Tight'); t3 = tiledlayout(t1,2,1,'TileSpacing','Tight'); t3.Layout.Tile = 2; nexttile(t2) plot(net) title("Autoencoder") nexttile(t3) plot(encNet) title("Encoder") nexttile(t3) plot(decNet) title("Decoder") pos = fig.Position; pos(3) = pos(3) + 200; pos(4) = pos(4) + 300; pos(2) = pos(2) - 300; fig.Position = pos; end function rootDir = exRoot() %exRoot Example root directory rootDir = fileparts(which("helperCSINetDLNetwork")); end function loss = nmseLossdB(x,xHat) %nmseLossdB NMSE loss in dB in = complex(x(:,:,1,:),x(:,:,2,:)); out = complex(xHat(:,:,1,:),xHat(:,:,2,:)); nmsePerObservation = helperNMSE(in,out); loss = mean(nmsePerObservation); end
参考文献
[1] Wen, Chao-Kai, Wan-Ting Shih, and Shi Jin. “Deep Learning for Massive MIMO CSI Feedback.” IEEE Wireless Communications Letters 7, no. 5 (October 2018): 748–51. https://doi.org/10.1109/LWC.2018.2818160.







