Main Content

ビーム選択用のニューラル ネットワーク

この例では、ニューラル ネットワークを使用してビーム選択タスクのオーバーヘッドを低減する方法を示します。この例では、通信チャネルの情報ではなく、受信機の位置のみを使用します。すべてのビーム ペアに対して網羅的なビーム探索を行う代わりに、選択した K 組のビーム ペアの中から探索することで、ビーム スイープのオーバーヘッドを低減できます。この例では、合計 16 組のビーム ペアをもつシステムでシミュレーションを行い、設計した機械学習アルゴリズムが、ビーム ペアの半分のみを網羅的に探索することで 90% の精度を達成できるという結果になります。

はじめに

ミリメートル波 (mmWave) 通信を有効にするには、ビーム管理技術を使用しなければなりません。これは、高周波数ではパス損失や遮断が大きいためです。ビーム管理は、良好な接続性に向けて最適なビーム ペア (送信ビームと対応する受信ビーム) を確立して保持するための、レイヤー 1 (物理レイヤー) とレイヤー 2 (メディア アクセス制御) の手順のセットです [1]。5G New Radio (NR) ビーム管理手順のシミュレーションについては、NR SSB のビーム スイーピング (5G Toolbox)およびCSI-RS を使用した NR ダウンリンク送信側のビーム調整 (5G Toolbox)の例を参照してください。

この例では、ユーザー端末 (UE) とアクセス ネットワーク ノード (gNB) の間で接続が確立されるときのビーム選択手順について考えます。5G NR において、初期アクセスのビーム選択手順はビーム スイープで構成されます。これを行うには、送信側と受信側ですべてのビームを網羅的に探索してから、最も強い基準信号受信強度 (RSRP) が得られるビーム ペアを選択する必要があります。mmWave 通信では多くのアンテナ素子 (つまり多くのビーム) が必要なため、すべてのビームを網羅的に探索すると計算コストが高くなり、初期アクセス時間が長くなります。

網羅的な探索を繰り返し実行することを避け、通信オーバーヘッドを低減するために、機械学習がビーム選択問題に適用されるようになりました。通常、ビーム選択問題は、ターゲット出力が最良のビーム ペア インデックスとなる分類タスクとして提起されます。LiDAR、GPS 信号、路傍カメラ イメージなどの外部情報は、各機械学習アルゴリズム ([2] ~ [6]) への入力として使用されます。具体的には、この帯域外情報が与えられると、学習済み機械学習モデルが K 組の適切なビーム ペアのセットを推奨します。すべてのビーム ペアを網羅的に探索する代わりに、シミュレーションは、選択された K 組のビーム ペアのみを探索することにより、ビーム スイープのオーバーヘッドを低減します。

この例では、ニューラル ネットワークを使用し、受信機の GPS 座標のみを使用してビーム選択を実行します。この例では、送信機と分布点の位置を固定して、学習サンプルのセットを生成します。各サンプルは、受信機の位置 (GPS データ) と真の最適なビーム ペア インデックス (送信と受信の両端ですべてのビーム ペアの網羅的探索を実行することによって検出) で構成されます。この例では、受信機の位置を入力として使用し、真の最適なビーム ペア インデックスを正解ラベルとして使用するニューラル ネットワークを設計し、それに学習させます。テスト フェーズの際、ニューラル ネットワークは最初に K 組の適切なビーム ペアを出力します。これらの K 組のビーム ペアに対する網羅的な探索が行われ、平均 RSRP が最も高いビーム ペアが、ニューラル ネットワークによって最終的な予測ビーム ペアとして選択されます。

この例では、平均 RSRP と上位 K 位の精度 ([2] ~ [6]) の 2 つのメトリクスを使用して、推奨された方法の有効性を測定します。次の図は、主な処理手順を示しています。

beamSelectionSchematic.png

rng(211);                           % Set RNG state for repeatability

学習データの生成

事前に記録されたデータでは、受信機が 6 メートル四方の周囲にランダムに配置され、16 組のビーム ペア (各端に 4 つのビーム、1 つの RF チェーンでアナログ ビームフォーミング) で構成されています。MIMO 散乱チャネルを設定した後、この例では、学習セット内の 200 の異なる受信機位置と、テスト セット内の 100 の異なる受信機位置を考慮します。事前に記録されたデータでは、2 次元の位置座標を使用します。具体的には、各サンプルの 3 番目の GPS 座標は常にゼロです。NR SSB Beam Sweeping の例のように、各位置に対して、SSB ベースのビーム スイープが実行され、16 組のビーム ペアすべてを網羅的に探索します。網羅的探索の際に AWGN が付加されるため、この例では各位置で 4 つの異なる試行を実行し、最も高い平均 RSRP をもつビーム ペアの選択によって、真の最適なビーム ペアを決定します。

新しい学習セットとテスト セットを生成するには、useSavedDataSaveData の logical 値を調整できます。データの再生成にはかなりの時間がかかることに注意してください。

useSavedData = true;
saveData = false;

if useSavedData
    load nnBS_prm.mat;              % Load beam selection system parameters
    load nnBS_TrainingData.mat;     % Load prerecorded training samples 
    %   (input: receiver's location; output: optimal beam pair indices)
    load nnBS_TestData.mat;         % Load prerecorded test samples
else

周波数とビーム スイープ角度の構成

    prm.NCellID = 1;                    % Cell ID
    prm.FreqRange = 'FR1';              % Frequency range: 'FR1' or 'FR2'   
    
    prm.CenterFreq = 2.5e9;             % Hz    
    prm.SSBlockPattern = 'Case B';      % Case A/B/C/D/E    
    prm.SSBTransmitted = [ones(1,4) zeros(1,0)]; % 4/8 or 64 in length
        
    prm.TxArraySize = [8 8];            % Transmit array size, [rows cols]
    prm.TxAZlim = [-163 177];           % Transmit azimuthal sweep limits
    prm.TxELlim = [-90 0];              % Transmit elevation sweep limits
    
    prm.RxArraySize = [2 2];            % Receive array size, [rows cols]    
    prm.RxAZlim = [-177 157];           % Receive azimuthal sweep limits
    prm.RxELlim = [0 90];               % Receive elevation sweep limits
    
    prm.ElevationSweep = false;         % Enable/disable elevation sweep
    prm.SNRdB = 30;                     % SNR, dB
    prm.RSRPMode = 'SSSwDMRS';          % {'SSSwDMRS', 'SSSonly'}
    
    prm = validateParams(prm);

同期信号バーストの構成

    txBurst = nrWavegenSSBurstConfig;
    txBurst.BlockPattern = prm.SSBlockPattern;
    txBurst.TransmittedBlocks = prm.SSBTransmitted;
    txBurst.Period = 20;
    txBurst.SubcarrierSpacingCommon = prm.SubcarrierSpacingCommon;

分布点の構成

    c = physconst('LightSpeed');   % Propagation speed
    prm.lambda = c/prm.CenterFreq; % Wavelength
    
    prm.rightCoorMax = 10;    % Maximum x-coordinate
    prm.topCoorMax = 10;      % Maximum y-coordinate
    prm.posTx = [3.5;4.2;0];  % Transmit array position, [x;y;z], meters           

    % Scatterer locations
    % Generate scatterers at random positions
    Nscat = 10;        % Number of scatterers 
    azRange = prm.TxAZlim(1):prm.TxAZlim(2);
    elRange = -90:90;    
            
    % More evenly spaced scatterers
    randAzOrder = round(linspace(1, length(azRange), Nscat));
    azAngInSph = azRange(randAzOrder(1:Nscat));   
    
    % Consider a 2-D area, i.e., the elevation angle is zero
    elAngInSph = zeros(size(azAngInSph));
    r = 2;            % radius
    [x,y,z] = sph2cart(deg2rad(azAngInSph),deg2rad(elAngInSph),r);
    prm.ScatPos = [x;y;z] + [prm.rightCoorMax/2;prm.topCoorMax/2;0];

アンテナ アレイの構成

    % Transmit array
    if prm.IsTxURA
        % Uniform rectangular array
        arrayTx = phased.URA(prm.TxArraySize,0.5*prm.lambda, ...
            'Element',phased.IsotropicAntennaElement('BackBaffled',true));
    else
        % Uniform linear array
        arrayTx = phased.ULA(prm.NumTx, ...
            'ElementSpacing',0.5*prm.lambda, ...
            'Element',phased.IsotropicAntennaElement('BackBaffled',true));
    end

    % Receive array
    if prm.IsRxURA
        % Uniform rectangular array
        arrayRx = phased.URA(prm.RxArraySize,0.5*prm.lambda, ...
            'Element',phased.IsotropicAntennaElement);
    else
        % Uniform linear array
        arrayRx = phased.ULA(prm.NumRx, ...
            'ElementSpacing',0.5*prm.lambda, ...
            'Element',phased.IsotropicAntennaElement);
    end

Tx/Rx 位置の決定

    % Receiver locations
    % Training data: X points around a rectangle: each side has X/4 random points
    % X: X/4 for around square, X/10 for validation => lcm(4,10) = 20 smallest
    NDiffLocTrain = 200;
    pointsEachSideTrain = NDiffLocTrain/4;
    prm.NDiffLocTrain = NDiffLocTrain;
    
    locationX = 2*ones(pointsEachSideTrain, 1);
    locationY = 2 + (8-2)*rand(pointsEachSideTrain, 1);
    
    locationX = [locationX; 2 + (8-2)*rand(pointsEachSideTrain, 1)];
    locationY = [locationY; 8*ones(pointsEachSideTrain, 1)];
    
    locationX = [locationX; 8*ones(pointsEachSideTrain, 1)];
    locationY = [locationY; 2 + (8-2)*rand(pointsEachSideTrain, 1)];  
    
    locationX = [locationX; 2 + (8-2)*rand(pointsEachSideTrain, 1)];
    locationY = [locationY; 2*ones(pointsEachSideTrain, 1)];   
    
    locationZ = zeros(size(locationX));
    locationMat = [locationX locationY locationZ];

    % Fixing receiver's location, run repeated simulations to consider
    % different realizations of AWGN
    prm.NRepeatSameLoc = 4;

    locationMatTrain = repelem(locationMat,prm.NRepeatSameLoc, 1);

    % Test data: Y points around a rectangle: each side has Y/4 random points
    % Different data than test, but a smaller number
    NDiffLocTest = 100;
    pointsEachSideTest = NDiffLocTest/4;
    prm.NDiffLocTest = NDiffLocTest;
    
    locationX = 2*ones(pointsEachSideTest, 1);
    locationY = 2 + (8-2)*rand(pointsEachSideTest, 1);
    
    locationX = [locationX; 2 + (8-2)*rand(pointsEachSideTest, 1)];
    locationY = [locationY; 8*ones(pointsEachSideTest, 1)];
    
    locationX = [locationX; 8*ones(pointsEachSideTest, 1)];
    locationY = [locationY; 2 + (8-2)*rand(pointsEachSideTest, 1)];  
    
    locationX = [locationX; 2 + (8-2)*rand(pointsEachSideTest, 1)];
    locationY = [locationY; 2*ones(pointsEachSideTest, 1)];   
    
    locationZ = zeros(size(locationX));
    locationMat = [locationX locationY locationZ];

    locationMatTest = repelem(locationMat,prm.NRepeatSameLoc,1);
    
    [optBeamPairIdxMatTrain,rsrpMatTrain] = hGenDataMIMOScatterChan('training',locationMatTrain,prm,txBurst,arrayTx,arrayRx,311);
    [optBeamPairIdxMatTest,rsrpMatTest] = hGenDataMIMOScatterChan('test',locationMatTest,prm,txBurst,arrayTx,arrayRx,411);
    
    % Save generated data
    if saveData
        save('nnBS_prm.mat','prm');
        save('nnBS_TrainingData.mat','optBeamPairIdxMatTrain','rsrpMatTrain','locationMatTrain');
        save('nnBS_TestData.mat','optBeamPairIdxMatTest','rsrpMatTest','locationMatTest');
    end
end

送信機と分布点の位置のプロット

figure
scatter(prm.posTx(1),prm.posTx(2),100,'r^','filled');
hold on;
scatter(prm.ScatPos(1,:),prm.ScatPos(2,:),100,[0.9290 0.6940 0.1250],'s','filled');
xlim([0 10])
ylim([0 10])
title('Transmitter and Scatterers Positions')
legend('Transmitter','Scatterers')
xlabel('x (m)')
ylabel('y (m)')

Figure contains an axes object. The axes object with title Transmitter and Scatterers Positions, xlabel x (m), ylabel y (m) contains 2 objects of type scatter. These objects represent Transmitter, Scatterers.

データ処理と可視化

次に、平均 RSRP が最も高いビーム ペアを真の最適なビーム ペアとしてラベル付けします。one-hot 符号化ラベルを categorical データに変換して、分類に使用します。最後に、categorical データを拡張して合計 16 クラスになるようにし、可能なビーム ペアの数と一致するようにします (クラスの要素数は異なる場合があります)。拡張の目的は、ニューラル ネットワークの出力が確実に目的の次元 16 になるようにすることです。

学習データの処理

% Choose the best beam pair by picking the one with the highest average RSRP
% (taking average over NRepeatSameLoc different trials at each location)
avgOptBeamPairIdxCellTrain = cell(size(optBeamPairIdxMatTrain, 1)/prm.NRepeatSameLoc, 1);
avgOptBeamPairIdxScalarTrain = zeros(size(optBeamPairIdxMatTrain, 1)/prm.NRepeatSameLoc, 1);
for locIdx = 1:size(optBeamPairIdxMatTrain, 1)/prm.NRepeatSameLoc
    avgRsrp = squeeze(rsrpMatTrain(:,:,locIdx));
    [~, targetBeamIdx] = max(avgRsrp(:));
    avgOptBeamPairIdxScalarTrain(locIdx) = targetBeamIdx;
    avgOptBeamPairIdxCellTrain{locIdx} = num2str(targetBeamIdx);
end

% Even though there are a total of 16 beam pairs, due to the fixed topology
% (transmitter/scatterers/receiver locations), it is possible
% that some beam pairs are never selected as an optimal beam pair
%
% Therefore, we augment the categories so 16 classes total are in the data
% (although some classes may have zero elements)
allBeamPairIdxCell = cellstr(string((1:prm.numBeams^2)'));
avgOptBeamPairIdxCellTrain = categorical(avgOptBeamPairIdxCellTrain, allBeamPairIdxCell);
NBeamPairInTrainData = numel(categories(avgOptBeamPairIdxCellTrain)); % Should be 16

テスト データの処理

% Decide the best beam pair by picking the one with the highest avg. RSRP
avgOptBeamPairIdxCellTest = cell(size(optBeamPairIdxMatTest, 1)/prm.NRepeatSameLoc, 1);
avgOptBeamPairIdxScalarTest = zeros(size(optBeamPairIdxMatTest, 1)/prm.NRepeatSameLoc, 1);
for locIdx = 1:size(optBeamPairIdxMatTest, 1)/prm.NRepeatSameLoc
    avgRsrp = squeeze(rsrpMatTest(:,:,locIdx));
    [~, targetBeamIdx] = max(avgRsrp(:));
    avgOptBeamPairIdxScalarTest(locIdx) = targetBeamIdx;
    avgOptBeamPairIdxCellTest{locIdx} = num2str(targetBeamIdx);
end
% Augment the categories such that the data has 16 classes total
avgOptBeamPairIdxCellTest = categorical(avgOptBeamPairIdxCellTest, allBeamPairIdxCell);
NBeamPairInTestData = numel(categories(avgOptBeamPairIdxCellTest)); % Should be 16

ニューラル ネットワークの入出力データの作成

trainDataLen = size(locationMatTrain, 1)/prm.NRepeatSameLoc;
trainOut = avgOptBeamPairIdxCellTrain;
sampledLocMatTrain = locationMatTrain(1:prm.NRepeatSameLoc:end, :);
trainInput = sampledLocMatTrain(1:trainDataLen, :);

% Take 10% data out of test data as validation data
valTestDataLen = size(locationMatTest, 1)/prm.NRepeatSameLoc;
valDataLen = round(0.1*size(locationMatTest, 1))/prm.NRepeatSameLoc;
testDataLen = valTestDataLen-valDataLen;
  
% Randomly shuffle the test data such that the distribution of the
% extracted validation data is closer to test data
rng(111)
shuffledIdx = randperm(prm.NDiffLocTest); 
avgOptBeamPairIdxCellTest = avgOptBeamPairIdxCellTest(shuffledIdx);
avgOptBeamPairIdxScalarTest = avgOptBeamPairIdxScalarTest(shuffledIdx);
rsrpMatTest = rsrpMatTest(:,:,shuffledIdx);

valOut = avgOptBeamPairIdxCellTest(1:valDataLen, :);
testOutCat = avgOptBeamPairIdxCellTest(1+valDataLen:end, :);

sampledLocMatTest = locationMatTest(1:prm.NRepeatSameLoc:end, :);
sampledLocMatTest = sampledLocMatTest(shuffledIdx, :);

valInput = sampledLocMatTest(1:valDataLen, :);
testInput = sampledLocMatTest(valDataLen+1:end, :);

学習データの最適なビーム ペア分布のプロット

各学習サンプル (合計 200 個) の位置と最適なビーム ペアをプロットします。各色は、1 つのビーム ペア インデックスを表します。つまり、同じ色のデータ点は同じクラスに属します。学習データ セットを増やして、各ビーム ペアの値が含まれるようにすることもできますが、ビーム ペアの実際の分布は分布点と送信機の位置に依存します。

figure
rng(111)    % for colors in plot
color = rand(NBeamPairInTrainData, 3);
uniqueOptBeamPairIdx = unique(avgOptBeamPairIdxScalarTrain);
for n = 1:length(uniqueOptBeamPairIdx)
    beamPairIdx = find(avgOptBeamPairIdxScalarTrain == uniqueOptBeamPairIdx(n));
    locX = sampledLocMatTrain(beamPairIdx, 1);
    locY = sampledLocMatTrain(beamPairIdx, 2);
    scatter(locX, locY, [], color(n, :)); 
    hold on;
end
scatter(prm.posTx(1),prm.posTx(2),100,'r^','filled');
scatter(prm.ScatPos(1,:),prm.ScatPos(2,:),100,[0.9290 0.6940 0.1250],'s','filled');
hold off
xlabel('x (m)')
ylabel('y (m)')
xlim([0 10])
ylim([0 10])
title('Optimal Beam Pair Indices (Training Data)')

Figure contains an axes object. The axes object with title Optimal Beam Pair Indices (Training Data), xlabel x (m), ylabel y (m) contains 18 objects of type scatter.

figure
histogram(trainOut)
title('Histogram of Optimal Beam Pair Indices (Training Data)')
xlabel('Beam Pair Index')
ylabel('Number of Occurrences')

Figure contains an axes object. The axes object with title Histogram of Optimal Beam Pair Indices (Training Data), xlabel Beam Pair Index, ylabel Number of Occurrences contains an object of type categoricalhistogram.

検証データの最適ビーム ペア分布のプロット

figure
rng(111)    % for colors in plot
color = rand(NBeamPairInTestData, 3);
uniqueOptBeamPairIdx = unique(avgOptBeamPairIdxScalarTest(1:valDataLen));
for n = 1:length(uniqueOptBeamPairIdx)
    beamPairIdx = find(avgOptBeamPairIdxScalarTest(1:valDataLen) == uniqueOptBeamPairIdx(n));
    locX = sampledLocMatTest(beamPairIdx, 1);
    locY = sampledLocMatTest(beamPairIdx, 2);
    scatter(locX, locY, [], color(n, :)); 
    hold on;
end
scatter(prm.posTx(1),prm.posTx(2),100,'r^','filled');
scatter(prm.ScatPos(1,:),prm.ScatPos(2,:),100,[0.9290 0.6940 0.1250],'s','filled');
hold off
xlabel('x (m)')
ylabel('y (m)')
xlim([0 10])
ylim([0 10])
title('Optimal Beam Pair Indices (Validation Data)')

Figure contains an axes object. The axes object with title Optimal Beam Pair Indices (Validation Data), xlabel x (m), ylabel y (m) contains 9 objects of type scatter.

figure
histogram(valOut)
title('Histogram of Optimal Beam Pair Indices (Validation Data)')
xlabel('Beam Pair Index')
ylabel('Number of Occurrences')

Figure contains an axes object. The axes object with title Histogram of Optimal Beam Pair Indices (Validation Data), xlabel Beam Pair Index, ylabel Number of Occurrences contains an object of type categoricalhistogram.

テスト データの最適ビーム ペア分布のプロット

figure
rng(111)    % for colors in plots
color = rand(NBeamPairInTestData, 3);
uniqueOptBeamPairIdx = unique(avgOptBeamPairIdxScalarTest(1+valDataLen:end));
for n = 1:length(uniqueOptBeamPairIdx)
    beamPairIdx = find(avgOptBeamPairIdxScalarTest(1+valDataLen:end) == uniqueOptBeamPairIdx(n));
    locX = sampledLocMatTest(beamPairIdx, 1);
    locY = sampledLocMatTest(beamPairIdx, 2);
    scatter(locX, locY, [], color(n, :)); 
    hold on;
end
scatter(prm.posTx(1),prm.posTx(2),100,'r^','filled');
scatter(prm.ScatPos(1,:),prm.ScatPos(2,:),100,[0.9290 0.6940 0.1250],'s','filled');
hold off
xlabel('x (m)')
ylabel('y (m)')
xlim([0 10])
ylim([0 10])
title('Optimal Beam Pair Indices (Test Data)')

Figure contains an axes object. The axes object with title Optimal Beam Pair Indices (Test Data), xlabel x (m), ylabel y (m) contains 16 objects of type scatter.

figure
histogram(testOutCat)
title('Histogram of Optimal Beam Pair Indices (Test Data)')
xlabel('Beam Pair Index')
ylabel('Number of Occurrences')

Figure contains an axes object. The axes object with title Histogram of Optimal Beam Pair Indices (Test Data), xlabel Beam Pair Index, ylabel Number of Occurrences contains an object of type categoricalhistogram.

ニューラル ネットワークの設計と学習

4 つの隠れ層を使用してニューラル ネットワークに学習させます。この設計は、[3] (4 つの隠れ層) と [5] (各層に 128 個のニューロンをもつ 2 つの隠れ層) で提唱されているもので、受信機の位置もニューラル ネットワークへの入力として考慮します。学習を有効にするには、doTraining の logical 値を調整します。

この例では、クラスに重みを付けるオプションも提供しています。発生頻度の高いクラスは重みが小さくなり、発生頻度が低いクラスは重みが大きくなります。クラスの重み付けを使用するには、useDiffClassWeights の logical 値を調整します。

ネットワークを変更して、さまざまな設計を実験します。提供されたデータ セットのいずれかを変更する場合は、変更されたデータ セットでネットワークに再学習させなければなりません。ネットワークの再学習には、かなりの時間がかかる場合があります。後続の実行で学習済みネットワークを使用するには、saveNet の logical 値を調整します。

doTraining = false;
useDiffClassWeights = false;
saveNet = false;

if doTraining    
    if useDiffClassWeights
        catCount = countcats(trainOut);
        catFreq = catCount/length(trainOut);
        nnzIdx = (catFreq ~= 0);
        medianCount = median(catFreq(nnzIdx));
        classWeights = 10*ones(size(catFreq));
        classWeights(nnzIdx) = medianCount./catFreq(nnzIdx);
        filename = 'nnBS_trainedNetwWeighting.mat';
    else
        classWeights = ones(1,NBeamPairInTestData);
        filename = 'nnBS_trainedNet.mat';        
    end
    
    % Neural network design
    layers = [ ...
        featureInputLayer(3,'Name','input','Normalization','rescale-zero-one') 
        
        fullyConnectedLayer(96,'Name','linear1')
        leakyReluLayer(0.01,'Name','leakyRelu1')
        
        fullyConnectedLayer(96,'Name','linear2')
        leakyReluLayer(0.01,'Name','leakyRelu2')    
    
        fullyConnectedLayer(96,'Name','linear3')
        leakyReluLayer(0.01,'Name','leakyRelu3') 
    
        fullyConnectedLayer(96,'Name','linear4')
        leakyReluLayer(0.01,'Name','leakyRelu4')  
    
        fullyConnectedLayer(NBeamPairInTrainData,'Name','linear5')
        softmaxLayer('Name','softmax')
        classificationLayer('ClassWeights',classWeights,'Classes',allBeamPairIdxCell,'Name','output')];
    
    maxEpochs = 1000;
    miniBatchSize = 256;
    
    options = trainingOptions('adam', ...
        'MaxEpochs',maxEpochs, ...
        'MiniBatchSize',miniBatchSize, ...
        'InitialLearnRate',1e-4, ...    
        'ValidationData',{valInput,valOut}, ...
        'ValidationFrequency',500, ...
        'OutputNetwork', 'best-validation-loss', ...
        'Shuffle','every-epoch', ...
        'Plots','training-progress', ...
        'ExecutionEnvironment','cpu', ...
        'Verbose',0);
    
    % Train the network
    net = trainNetwork(trainInput,trainOut,layers,options);

    if saveNet
        save(filename,'net');
    end
else
    if useDiffClassWeights
        load 'nnBS_trainedNetwWeighting.mat';
    else
        load 'nnBS_trainedNet.mat';
    end
end

異なるアプローチの比較: 上位 K 位の精度

このセクションでは、上位 K 位の精度のメトリクスを考慮して、目には見えないテスト データで学習済みのネットワークをテストします。上位 K 位の精度のメトリクスは、ニューラル ネットワークベースのビーム選択タスク ([2] ~ [6]) で広く使用されています。

受信機の位置が与えられると、ニューラル ネットワークは最初に K 組の推奨ビーム ペアを出力します。次に、これらの K 組のビーム ペアに対して網羅的なシーケンシャル探索を実行し、平均 RSRP が最も高いものを最終予測として選択します。最終的に選択されたビーム ペアが真の最適なビーム ペアであった場合、予測は成功します。同様に、ニューラル ネットワークによって推奨される K 組のビーム ペアのうちの 1 つが真の最適なビーム ペアであった場合、成功と見なされます。

3 つのベンチマークを比較します。方式ごとに K 個の推奨ビーム ペアが生成されます。

  1. KNN - この手法は、テスト サンプルのために、最初に GPS 座標に基づいて K 個の最も近い学習サンプルを収集します。この手法では、これらの K 個の学習サンプルに関連付けられたすべてのビーム ペアを推奨します。各学習サンプルには対応する最適なビーム ペアがあるため、推奨されるビーム ペアの数は多くても K 組になります (一部のビーム ペアは同じである可能性がある)。

  2. 統計的情報 [5] - この手法は、最初に学習セット内での相対頻度に従ってすべてのビーム ペアをランク付けし、次に最初の K 組のビーム ペアを常に選択します。

  3. ランダム [5] - この手法は、テスト サンプルのために、ランダムに K 組のビーム ペアを選択します。

このプロットは、K=8 の場合の精度が既に 90% を超えていることを示しています。これは、ビーム選択タスクに学習済みニューラル ネットワークを使用することの有効性を強調しています。K=16 の場合、すべての方式 (KNN 以外) で 16 組のビーム ペアすべての網羅的な探索に緩和され、100% の精度が達成されます。しかし、K=16, KNN の場合は、16 個の最も近い学習サンプルが考慮され、これらのサンプルからの "異なる" ビーム ペアの数は多くの場合 16 未満になります。そのため、KNN は 100% の精度を達成しません。

rng(111)    % for repeatability of the "Random" policy
testOut = avgOptBeamPairIdxScalarTest(1+valDataLen:end, :);
statisticCount = countcats(testOutCat);
predTestOutput = predict(net,testInput,'ExecutionEnvironment','cpu');

K = prm.numBeams^2;
accNeural = zeros(1,K);
accKNN = zeros(1,K);
accStatistic = zeros(1,K);
accRandom = zeros(1,K);                
for k = 1:K    
    predCorrectNeural = zeros(testDataLen,1);      
    predCorrectKNN = zeros(testDataLen,1); 
    predCorrectStats = zeros(testDataLen,1);  
    predCorrectRandom = zeros(testDataLen,1);
    knnIdx = knnsearch(trainInput,testInput,'K',k);

    for n = 1:testDataLen 
        trueOptBeamIdx = testOut(n);  

        % Neural Network
        [~, topKPredOptBeamIdx] = maxk(predTestOutput(n, :),k);
        if sum(topKPredOptBeamIdx == trueOptBeamIdx) > 0 
            % if true, then the true correct index belongs to one of the K predicted indices
            predCorrectNeural(n,1) = 1;
        end 
        
        % KNN
        neighborsIdxInTrainData = knnIdx(n,:);
        topKPredOptBeamIdx= avgOptBeamPairIdxScalarTrain(neighborsIdxInTrainData);      
        if sum(topKPredOptBeamIdx == trueOptBeamIdx) > 0 
            % if true, then the true correct index belongs to one of the K predicted indices
            predCorrectKNN(n,1) = 1;
        end  
        
        % Statistical Info
        [~, topKPredOptBeamIdx] = maxk(statisticCount,k);
        if sum(topKPredOptBeamIdx == trueOptBeamIdx) > 0 
            % if true, then the true correct index belongs to one of the K predicted indices
            predCorrectStats(n,1) = 1;
        end           
        
        % Random
        topKPredOptBeamIdx = randperm(prm.numBeams*prm.numBeams,k);
        if sum(topKPredOptBeamIdx == trueOptBeamIdx) > 0 
            % if true, then the true correct index belongs to one of the K predicted indices
            predCorrectRandom(n,1) = 1;
        end                  

    end

    accNeural(k)    = sum(predCorrectNeural)/testDataLen*100;
    accKNN(k)       = sum(predCorrectKNN)/testDataLen*100;
    accStatistic(k) = sum(predCorrectStats)/testDataLen*100;
    accRandom(k)    = sum(predCorrectRandom)/testDataLen*100;    
    
end

figure
lineWidth = 1.5;
colorNeural = [0 0.4470 0.7410];
colorKNN = [0.8500 0.3250 0.0980];
colorStats = [0.4940 0.1840 0.5560];
colorRandom = [0.4660 0.6740 0.1880];
plot(1:K,accNeural,'--*','LineWidth',lineWidth,'Color',colorNeural)
hold on
plot(1:K,accKNN,'--o','LineWidth',lineWidth,'Color',colorKNN)
plot(1:K,accStatistic,'--s','LineWidth',lineWidth,'Color',colorStats)
plot(1:K,accRandom,'--d','LineWidth',lineWidth,'Color',colorRandom)
hold off
grid on
xticks(1:K)
xlabel('$K$','interpreter','latex')
ylabel('Top-$K$ Accuracy','interpreter','latex')
title('Performance Comparison of Different Beam Pair Selection Schemes')
legend('Neural Network','KNN','Statistical Info','Random','Location','best')

Figure contains an axes object. The axes object with title Performance Comparison of Different Beam Pair Selection Schemes, xlabel $K$, ylabel Top-$K$ Accuracy contains 4 objects of type line. These objects represent Neural Network, KNN, Statistical Info, Random.

異なるアプローチの比較: 平均 RSRP

目には見えないテスト データを使用して、ニューラル ネットワークと 3 つのベンチマークによって達成された平均 RSRP を計算します。プロットは、学習済みニューラル ネットワークを使用すると、最適な網羅的探索に近い平均 RSRP が得られることを示しています。

rng(111)    % for repeatability of the "Random" policy
K = prm.numBeams^2;
rsrpOptimal = zeros(1,K);
rsrpNeural = zeros(1,K);
rsrpKNN = zeros(1,K);
rsrpStatistic = zeros(1,K);
rsrpRandom = zeros(1,K);
for k = 1:K
    rsrpSumOpt = 0;
    rsrpSumNeural = 0;
    rsrpSumKNN = 0;
    rsrpSumStatistic = 0;
    rsrpSumRandom = 0;
    
    knnIdx = knnsearch(trainInput,testInput,'K',k);

    for n = 1:testDataLen
        % Exhaustive Search
        trueOptBeamIdx = testOut(n);  
        rsrp = rsrpMatTest(:,:,valDataLen+n);
        rsrpSumOpt = rsrpSumOpt + rsrp(trueOptBeamIdx);
        
        % Neural Network
        [~, topKPredOptCatIdx] = maxk(predTestOutput(n, :),k);    
        rsrpSumNeural = rsrpSumNeural + max(rsrp(topKPredOptCatIdx));         
      
        % KNN
        neighborsIdxInTrainData = knnIdx(n,:);
        topKPredOptBeamIdxKNN = avgOptBeamPairIdxScalarTrain(neighborsIdxInTrainData);    
        rsrpSumKNN = rsrpSumKNN + max(rsrp(topKPredOptBeamIdxKNN));  
        
        % Statistical Info
        [~, topKPredOptCatIdxStat] = maxk(statisticCount,k);
        rsrpSumStatistic = rsrpSumStatistic + max(rsrp(topKPredOptCatIdxStat));
        
        % Random
        topKPredOptBeamIdxRand = randperm(prm.numBeams*prm.numBeams,k);
        rsrpSumRandom = rsrpSumRandom + max(rsrp(topKPredOptBeamIdxRand));        
    end    
    rsrpOptimal(k)  = rsrpSumOpt/testDataLen/prm.NRepeatSameLoc;
    rsrpNeural(k)   = rsrpSumNeural/testDataLen/prm.NRepeatSameLoc;
    rsrpKNN(k)      = rsrpSumKNN/testDataLen/prm.NRepeatSameLoc;
    rsrpStatistic(k) = rsrpSumStatistic/testDataLen/prm.NRepeatSameLoc;
    rsrpRandom(k)   = rsrpSumRandom/testDataLen/prm.NRepeatSameLoc;
end

figure
lineWidth = 1.5;
plot(1:K,rsrpOptimal,'--h','LineWidth',lineWidth,'Color',[0.6350 0.0780 0.1840]);
hold on
plot(1:K,rsrpNeural,'--*','LineWidth',lineWidth,'Color',colorNeural)
plot(1:K,rsrpKNN,'--o','LineWidth',lineWidth,'Color',colorKNN)
plot(1:K,rsrpStatistic,'--s','LineWidth',lineWidth,'Color',colorStats)
plot(1:K,rsrpRandom,'--d','LineWidth',lineWidth, 'Color',colorRandom)
hold off
grid on
xticks(1:K)
xlabel('$K$','interpreter','latex')
ylabel('Average RSRP')
title('Performance Comparison of Different Beam Pair Selection Schemes')
legend('Exhaustive Search','Neural Network','KNN','Statistical Info','Random','Location','best')

Figure contains an axes object. The axes object with title Performance Comparison of Different Beam Pair Selection Schemes, xlabel $K$, ylabel Average RSRP contains 5 objects of type line. These objects represent Exhaustive Search, Neural Network, KNN, Statistical Info, Random.

最適な手法、ニューラル ネットワーク、および KNN でのアプローチの最終値を比較します。

[rsrpOptimal(end-3:end); rsrpNeural(end-3:end); rsrpKNN(end-3:end);]
ans = 3×4

   80.7363   80.7363   80.7363   80.7363
   80.7363   80.7363   80.7363   80.7363
   80.5067   80.5068   80.5069   80.5212

KNN と最適な手法の間でパフォーマンスのギャップがあることは、より大きなビーム ペアのセット (たとえば 256) を考慮した場合でも、KNN がうまく機能しない可能性があることを示しています。

混同行列のプロット

要素が少ないクラスは、学習済みネットワークで悪影響を受けることがわかります。異なるクラスに対して異なる重みを使用することで、これを回避できる可能性があります。useDiffClassWeights の logical 値を使用し、クラスごとにカスタムの重みを指定して同じように調査します。

predLabels = classify(net,testInput,'ExecutionEnvironment','cpu');
figure;
cm = confusionchart(testOutCat,predLabels);
title('Confusion Matrix')

Figure contains an object of type ConfusionMatrixChart. The chart of type ConfusionMatrixChart has title Confusion Matrix.

まとめとその他の調査

この例では、5G NR システムのビーム選択タスクへのニューラル ネットワークの応用について説明しています。K 組の適切なビーム ペアのセットを出力するニューラル ネットワークを設計し、それに学習させることができます。ビーム スイープのオーバーヘッドは、選択された K 組のビーム ペアのみを網羅的に探索することで低減できます。

この例では、MIMO チャネル内の分布点を指定できます。ビーム選択におけるチャネルの影響を確認するには、さまざまなシナリオを試してください。この例では、さまざまなネットワーク構造とハイパーパラメーターの学習を実験するために使用できる保存済みのデータセットも提供しています。

シミュレーションの結果から、16 組のビーム ペアに対する事前記録された MIMO 散乱チャネルの場合、提案されたアルゴリズムは、K=8 の場合に上位 K 位の精度 90% を達成できます。これは、ニューラル ネットワークを使用する場合に、すべてのビーム ペアの半分のみを網羅的に探索すればよく、ビーム スイープのオーバーヘッドを 50% 削減できることを示しています。他のシステム パラメーターを変更して実験し、データを再生成してネットワークの有効性を確認した後、ネットワークに再学習させて再テストします。

参考文献

  1. 3GPP TR 38.802, "Study on New Radio access technology physical layer aspects." 3rd Generation Partnership Project; Technical Specification Group Radio Access Network.

  2. Klautau, A., González-Prelcic, N., and Heath, R. W., "LIDAR data for deep learning-based mmWave beam-selection," IEEE Wireless Communications Letters, vol. 8, no. 3, pp. 909–912, Jun. 2019.

  3. Heng, Y., and Andrews, J. G., "Machine Learning-Assisted Beam Alignment for mmWave Systems," 2019 IEEE Global Communications Conference (GLOBECOM), 2019, pp. 1-6, doi: 10.1109/GLOBECOM38437.2019.9013296.

  4. Klautau, A., Batista, P., González-Prelcic, N., Wang, Y., and Heath, R. W., "5G MIMO Data for Machine Learning: Application to Beam-Selection Using Deep Learning," 2018 Information Theory and Applications Workshop (ITA), 2018, pp. 1-9, doi: 10.1109/ITA.2018.8503086.

  5. Matteo, Z., <https://github.com/ITU-AI-ML-in-5G-Challenge/PS-012-ML5G-PHY-Beam-Selection_BEAMSOUP> (This is the team achieving the highest test score in the ITU Artificial Intelligence/Machine Learning in 5G Challenge in 2020).

  6. Sim, M. S., Lim, Y., Park, S. H., Dai, L., and Chae, C., "Deep Learning-Based mmWave Beam Selection for 5G NR/6G With Sub-6 GHz Channel Information: Algorithms and Prototype Validation," IEEE Access, vol. 8, pp. 51634-51646, 2020.

ローカル関数

function prm = validateParams(prm)
% Validate user specified parameters and return updated parameters
%
% Only cross-dependent checks are made for parameter consistency.

    if strcmpi(prm.FreqRange,'FR1')
        if prm.CenterFreq > 7.125e9 || prm.CenterFreq < 410e6
            error(['Specified center frequency is outside the FR1 ', ...
                   'frequency range (410 MHz - 7.125 GHz).']);
        end
        if strcmpi(prm.SSBlockPattern,'Case D') ||  ...
           strcmpi(prm.SSBlockPattern,'Case E')
            error(['Invalid SSBlockPattern for selected FR1 frequency ' ...
                'range. SSBlockPattern must be one of ''Case A'' or ' ...
                '''Case B'' or ''Case C'' for FR1.']);
        end
        if ~((length(prm.SSBTransmitted)==4) || ...
             (length(prm.SSBTransmitted)==8))
            error(['SSBTransmitted must be a vector of length 4 or 8', ...
                   'for FR1 frequency range.']);
        end
        if (prm.CenterFreq <= 3e9) && (length(prm.SSBTransmitted)~=4)
            error(['SSBTransmitted must be a vector of length 4 for ' ...
                   'center frequency less than or equal to 3GHz.']);
        end
        if (prm.CenterFreq > 3e9) && (length(prm.SSBTransmitted)~=8)
            error(['SSBTransmitted must be a vector of length 8 for ', ...
                   'center frequency greater than 3GHz and less than ', ...
                   'or equal to 7.125GHz.']);
        end
    else % 'FR2'
        if prm.CenterFreq > 52.6e9 || prm.CenterFreq < 24.25e9
            error(['Specified center frequency is outside the FR2 ', ...
                   'frequency range (24.25 GHz - 52.6 GHz).']);
        end
        if ~(strcmpi(prm.SSBlockPattern,'Case D') || ...
                strcmpi(prm.SSBlockPattern,'Case E'))
            error(['Invalid SSBlockPattern for selected FR2 frequency ' ...
                'range. SSBlockPattern must be either ''Case D'' or ' ...
                '''Case E'' for FR2.']);
        end
        if length(prm.SSBTransmitted)~=64
            error(['SSBTransmitted must be a vector of length 64 for ', ...
                   'FR2 frequency range.']);
        end
    end

    % Number of beams at transmit/receive ends
    prm.numBeams = sum(prm.SSBTransmitted);
    
    prm.NumTx = prod(prm.TxArraySize);
    prm.NumRx = prod(prm.RxArraySize);    
    if prm.NumTx==1 || prm.NumRx==1
        error(['Number of transmit or receive antenna elements must be', ... 
               ' greater than 1.']);
    end
    prm.IsTxURA = (prm.TxArraySize(1)>1) && (prm.TxArraySize(2)>1);
    prm.IsRxURA = (prm.RxArraySize(1)>1) && (prm.RxArraySize(2)>1);
    
    if ~( strcmpi(prm.RSRPMode,'SSSonly') || ...
          strcmpi(prm.RSRPMode,'SSSwDMRS') )
        error(['Invalid RSRP measuring mode. Specify either ', ...
               '''SSSonly'' or ''SSSwDMRS'' as the mode.']);
    end

    % Select SCS based on SSBlockPattern
    switch lower(prm.SSBlockPattern)
        case 'case a'
            scs = 15;
            cbw = 10;
            scsCommon = 15;
        case {'case b', 'case c'}
            scs = 30;
            cbw = 25;
            scsCommon = 30;
        case 'case d'
            scs = 120;
            cbw = 100;
            scsCommon = 120;
        case 'case e'
            scs = 240;
            cbw = 200;
            scsCommon = 120;
    end
    prm.SCS = scs;
    prm.ChannelBandwidth = cbw;
    prm.SubcarrierSpacingCommon = scsCommon;
end

参考

関数

オブジェクト

関連するトピック