Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

深層学習による変調の分類

この例では、畳み込みニューラル ネットワーク (CNN) を変調分類に使用する方法を説明します。チャネルで劣化した合成波形を生成します。生成された波形を学習データとして使用して、変調分類用の CNN を学習させます。その後、ソフトウェア無線 (SDR) ハードウェアと無線信号を使用して CNN をテストします。

CNN を使用した変調タイプの予測

この例の学習済み CNN は、以下の 8 つのデジタル変調タイプと 3 つのアナログ変調タイプを認識します。

  • 2 位相偏移変調 (BPSK)

  • 直交位相偏移変調 (QPSK)

  • 8-ary 位相偏移変調 (8-PSK)

  • 16-ary 直交振幅変調 (16-QAM)

  • 64-ary 直交振幅変調 (64-QAM)

  • 4-ary パルス振幅変調 (PAM4)

  • ガウス周波数偏移変調 (GFSK)

  • 連続位相周波数偏移変調 (CPFSK)

  • ブロードキャスト FM (B-FM)

  • 両側波帯振幅変調 (DSB-AM)

  • 単側波帯振幅変調 (SSB-AM)

modulationTypes = categorical(["BPSK", "QPSK", "8PSK", ...
  "16QAM", "64QAM", "PAM4", "GFSK", "CPFSK", ...
  "B-FM", "DSB-AM", "SSB-AM"]);

まず、学習済みネットワークを読み込みます。ネットワーク学習の詳細は、CNN の学習の節を参照してください。

load trainedModulationClassificationNetwork
trainedNet
trainedNet = 
  SeriesNetwork with properties:

         Layers: [28×1 nnet.cnn.layer.Layer]
     InputNames: {'Input Layer'}
    OutputNames: {'Output'}

この学習済み CNN は、チャネルで劣化した 1024 個のサンプルを取得し、各フレームの変調タイプを予測します。ライス マルチパス フェージング、中心周波数とサンプリング時間のドリフト、および AWGN で劣化したいくつかの PAM4 フレームを生成します。以下の関数を使用して、CNN をテストするための合成信号を生成します。その後、CNN を使用してフレームの変調タイプを予測します。

  • randi:ランダムなビットの生成

  • pammod: ビットの PAM4 変調

  • rcosdesign:ルート レイズド コサイン パルス整形フィルターの設計

  • filter:シンボルのパルス整形

  • comm.RicianChannel:ライス マルチパス チャネルの適用

  • comm.PhaseFrequencyOffset:クロック オフセットによる位相シフト/周波数シフトの適用

  • interp1:クロック オフセットによるタイミングのずれの適用

  • awgn:AWGN の追加

% Set the random number generator to a known state to be able to regenerate
% the same frames every time the simulation is run
rng(123456)
% Random bits
d = randi([0 3], 1024, 1);
% PAM4 modulation
syms = pammod(d,4);
% Square-root raised cosine filter
filterCoeffs = rcosdesign(0.35,4,8);
tx = filter(filterCoeffs,1,upsample(syms,8));

% Channel
SNR = 30;
maxOffset = 5;
fc = 902e6;
fs = 200e3;
multipathChannel = comm.RicianChannel(...
  'SampleRate', fs, ...
  'PathDelays', [0 1.8 3.4] / 200e3, ...
  'AveragePathGains', [0 -2 -10], ...
  'KFactor', 4, ...
  'MaximumDopplerShift', 4);

frequencyShifter = comm.PhaseFrequencyOffset(...
  'SampleRate', fs);

% Apply an independent multipath channel
reset(multipathChannel)
outMultipathChan = multipathChannel(tx);

% Determine clock offset factor
clockOffset = (rand() * 2*maxOffset) - maxOffset;
C = 1 + clockOffset / 1e6;

% Add frequency offset
frequencyShifter.FrequencyOffset = -(C-1)*fc;
outFreqShifter = frequencyShifter(outMultipathChan);

% Add sampling time drift
t = (0:length(tx)-1)' / fs;
newFs = fs * C;
tp = (0:length(tx)-1)' / newFs;
outTimeDrift = interp1(t, outFreqShifter, tp);

% Add noise
rx = awgn(outTimeDrift,SNR,0);

% Frame generation for classification
unknownFrames = helperModClassGetNNFrames(rx);

% Classification
[prediction1,score1] = classify(trainedNet,unknownFrames);

硬判定に似た分類器の予測を返します。このネットワークはフレームを PAM4 フレームとして正しく識別します。変調信号の生成の詳細については、関数 helperModClassGetModulator を参照してください。

prediction1
prediction1 = 7×1 categorical
     PAM4 
     PAM4 
     PAM4 
     PAM4 
     PAM4 
     PAM4 
     PAM4 

分類器は各フレームのスコアのベクトルも返します。スコアは、各フレームが予測された変調タイプである確率に相当します。スコアをプロットします。

helperModClassPlotScores(score1,modulationTypes)

変調分類やその他のタスクに CNN を使用する前に、まず、既知の (ラベル付きの) データを使用してネットワークを学習させる必要があります。この例の最初の部分は、変調器、フィルター、およびチャネル障害などの Communications Toolbox™ の機能を使用して、合成学習データを生成する方法を示しています。2 番目の部分では、変調分類のタスクのために CNN を定義し、学習させて、テストすることに焦点を当てています。3 番目の部分では、ソフトウェア無線 (SDR) プラットフォームを使用して、無線信号でネットワーク性能をテストします。

学習のための波形生成

各変調タイプに対して 10,000 フレームを生成します。このうち 80% は学習に使用され、10% は検証に使用され、10% はテストに使用されます。ネットワーク学習フェーズでは、学習フレームと検証フレームを使用します。最終的な分類精度は、テスト フレームを使用して得ます。各フレームの長さは 1024 サンプルで、サンプルレートは 200 kHz です。デジタル変調タイプでは、8 つのサンプルがシンボルを表します。ネットワークでは、(ビデオのように) 複数の連続したフレームではなく、単一のフレームに基づいて各判定が行われます。デジタル変調タイプとアナログ変調タイプの中心周波数を、それぞれ 902 MHz と 100 MHz と仮定します。

この例を迅速に実行するには、学習済みネットワークを使用して少数の学習フレームを生成します。コンピューターでネットワークに学習させるには、"Train network now" オプションを選択します (つまり、trainNow を true に設定します)。

trainNow = false;
if trainNow == true
  numFramesPerModType = 10000;
else
  numFramesPerModType = 200;
end
percentTrainingSamples = 80;
percentValidationSamples = 10;
percentTestSamples = 10;

sps = 8;                % Samples per symbol
spf = 1024;             % Samples per frame
symbolsPerFrame = spf / sps;
fs = 200e3;             % Sample rate
fc = [902e6 100e6];     % Center frequencies

チャネル障害の作成

以下の障害があるチャネルを通して各フレームを渡します。

  • AWGN

  • ライス マルチパス フェージング

  • 中心周波数オフセットとサンプリング時間のドリフトをもたらすクロック オフセット

この例のネットワークは単一のフレームに基づいて判定を行うため、各フレームは独立したチャネルを通過しなければなりません。

AWGN

チャネルは 30 dB の SNR で AWGN を追加します。関数 awgn を使用してチャネルを実装します。

ライス マルチパス

チャネルは、comm.RicianChannelSystem object™ を使用して、ライス マルチパス フェージング チャネル経由で信号を渡します。対応する平均パス ゲインが [0 -2 -10] dB の [0 1.8 3.4] サンプルの遅延プロファイルを仮定します。K ファクターは 4、最大ドップラー シフトは 4 Hz です。これは、902 MHz での歩行速度に相当します。以下の設定を使用してチャネルを実装します。

クロック オフセット

クロック オフセットは、送信機と受信機の内部クロック ソースが不正確なことが原因で発生します。クロック オフセットは、信号をベースバンドにダウンコンバートするために使用される中心周波数と、デジタルからアナログへの変換器のサンプリング レートが理想値とは異なる値になる原因となります。このチャネル シミュレーターは、C=1+Δclock106 として表されるクロック オフセット係数 C を使用します。ここで、Δclock はクロック オフセットです。各フレームで、チャネルは範囲 [-maxΔclock maxΔclock] 内の一様分布の値のセットから、乱数の Δclock 値を生成します。ここで、maxΔclock は最大クロック オフセットです。クロック オフセットは 100 万分の 1 (ppm) 単位で測定されます。この例では、最大クロック オフセットを 5 ppm と仮定します。

maxDeltaOff = 5;
deltaOff = (rand()*2*maxDeltaOff) - maxDeltaOff;
C = 1 + (deltaOff/1e6);

周波数オフセット

クロック オフセット係数 C と中心周波数に基づいて、各フレームに周波数オフセットを適用します。comm.PhaseFrequencyOffsetを使用してチャネルを実装します。

サンプリング レート オフセット

クロック オフセット係数 C に基づいて、各フレームにサンプリング レート オフセットを適用します。関数interp1を使用してチャネルを実装し、新しいレートの C×fs でフレームをリサンプリングします。

結合されたチャネル

helperModClassTestChannel オブジェクトを使用して、3 つすべてのチャネル劣化要因をフレームに適用します。

channel = helperModClassTestChannel(...
  'SampleRate', fs, ...
  'SNR', SNR, ...
  'PathDelays', [0 1.8 3.4] / fs, ...
  'AveragePathGains', [0 -2 -10], ...
  'KFactor', 4, ...
  'MaximumDopplerShift', 4, ...
  'MaximumClockOffset', 5, ...
  'CenterFrequency', 902e6)
channel = 
  helperModClassTestChannel with properties:

                    SNR: 30
        CenterFrequency: 902000000
             SampleRate: 200000
             PathDelays: [0 9.0000e-06 1.7000e-05]
       AveragePathGains: [0 -2 -10]
                KFactor: 4
    MaximumDopplerShift: 4
     MaximumClockOffset: 5

オブジェクト関数 info を使用して、チャネルに関する基本情報を表示できます。

chInfo = info(channel)
chInfo = struct with fields:
               ChannelDelay: 6
     MaximumFrequencyOffset: 4510
    MaximumSampleRateOffset: 1

波形生成

各変調タイプのチャネルで劣化したフレームを生成し、そのフレームに対応するラベルを付けて MAT ファイルに格納するループを作成します。データをファイルに保存しておけば、この例を実行するたびにデータを生成する必要がなくなります。データをより効率的に共有することもできます。

各フレームの先頭からランダムな数のサンプルを削除して、過渡状態を削除し、フレームがシンボルの境界に対してランダムな開始点をもつようにします。

% Set the random number generator to a known state to be able to regenerate
% the same frames every time the simulation is run
rng(12)

tic

numModulationTypes = length(modulationTypes);

channelInfo = info(channel);
transDelay = 50;
pool = getPoolSafe();
if ~isa(pool,"parallel.ClusterPool")
  dataDirectory = fullfile(tempdir,"ModClassDataFiles");
else
  dataDirectory = uigetdir("","Select network location to save data files");
end
disp("Data file directory is " + dataDirectory)
Data file directory is C:\TEMP\ModClassDataFiles
fileNameRoot = "frame";

% Check if data files exist
dataFilesExist = false;
if exist(dataDirectory,'dir')
  files = dir(fullfile(dataDirectory,sprintf("%s*",fileNameRoot)));
  if length(files) == numModulationTypes*numFramesPerModType
    dataFilesExist = true;
  end
end

if ~dataFilesExist
  disp("Generating data and saving in data files...")
  [success,msg,msgID] = mkdir(dataDirectory);
  if ~success
    error(msgID,msg)
  end
  for modType = 1:numModulationTypes
    elapsedTime = seconds(toc);
    elapsedTime.Format = 'hh:mm:ss';
    fprintf('%s - Generating %s frames\n', ...
      elapsedTime, modulationTypes(modType))
    
    label = modulationTypes(modType);
    numSymbols = (numFramesPerModType / sps);
    dataSrc = helperModClassGetSource(modulationTypes(modType), sps, 2*spf, fs);
    modulator = helperModClassGetModulator(modulationTypes(modType), sps, fs);
    if contains(char(modulationTypes(modType)), {'B-FM','DSB-AM','SSB-AM'})
      % Analog modulation types use a center frequency of 100 MHz
      channel.CenterFrequency = 100e6;
    else
      % Digital modulation types use a center frequency of 902 MHz
      channel.CenterFrequency = 902e6;
    end
    
    for p=1:numFramesPerModType
      % Generate random data
      x = dataSrc();
      
      % Modulate
      y = modulator(x);
      
      % Pass through independent channels
      rxSamples = channel(y);
      
      % Remove transients from the beginning, trim to size, and normalize
      frame = helperModClassFrameGenerator(rxSamples, spf, spf, transDelay, sps);
      
      % Save data file
      fileName = fullfile(dataDirectory,...
        sprintf("%s%s%03d",fileNameRoot,modulationTypes(modType),p));
      save(fileName,"frame","label")
    end
  end
else
  disp("Data files exist. Skip data generation.")
end
Generating data and saving in data files...
00:00:00 - Generating BPSK frames
00:00:01 - Generating QPSK frames
00:00:02 - Generating 8PSK frames
00:00:03 - Generating 16QAM frames
00:00:04 - Generating 64QAM frames
00:00:05 - Generating PAM4 frames
00:00:07 - Generating GFSK frames
00:00:08 - Generating CPFSK frames
00:00:09 - Generating B-FM frames
00:00:10 - Generating DSB-AM frames
00:00:11 - Generating SSB-AM frames
% Plot the amplitude of the real and imaginary parts of the example frames
% against the sample number
helperModClassPlotTimeDomain(dataDirectory,modulationTypes,fs)

% Plot the spectrogram of the example frames
helperModClassPlotSpectrogram(dataDirectory,modulationTypes,fs,sps)

データストアの作成

生成された複素数波形を含むファイルを signalDatastore オブジェクトを使用して管理します。データストアは、個々のファイルはメモリに収まっても全体が収まるとは限らない場合に特に便利です。

frameDS = signalDatastore(dataDirectory,'SignalVariableNames',["frame","label"]);

実数配列への複素信号の変換

この例の深層学習ネットワークでは実数入力を想定していますが、受信信号には複素数のベースバンド サンプルが含まれています。複素信号を実数値の 4 次元配列に変換します。出力フレームのサイズは [1 x spf x 2 x N] です。ここで、最初のページ (3 番目の次元) は同相サンプル、2 番目のページは直交サンプルです。畳み込みフィルターのサイズを 1 x spf にすると、この方法により畳み込み層でも I と Q の情報が結合され、位相情報が使いやすくなります。詳細については、helperModClassIQAsPages を参照してください。

frameDSTrans = transform(frameDS,@helperModClassIQAsPages);

学習、検証、およびテストへの分割

次に、フレームを学習データ、検証データ、およびテスト データに分割します。詳細については、helperModClassSplitData を参照してください。

splitPercentages = [percentTrainingSamples,percentValidationSamples,percentTestSamples];
[trainDSTrans,validDSTrans,testDSTrans] = helperModClassSplitData(frameDSTrans,splitPercentages);

メモリへのデータのインポート

ニューラル ネットワークの学習は反復的です。データストアでは、反復のたびにファイルからデータを読み取り、そのデータを変換してからネットワークの係数を更新します。データがコンピューターのメモリに収まる場合、データをファイルからメモリにインポートしておけば、繰り返し実行されるこのファイルからの読み取りと変換の処理が不要になり、学習が高速になります。この場合、データのファイルからの読み取りと変換が 1 回で済みます。

ファイル内のすべてのデータをメモリにインポートします。ファイルには framelabel の 2 つの変数があり、データストアに対するそれぞれの read の呼び出しは cell 配列を返します。ここで、最初の要素が frame で、2 番目の要素が label です。transform の関数 helperModClassReadFramehelperModClassReadLabel を使用してフレームとラベルを読み取ります。Parallel Computing Toolbox™ ライセンスがある場合は、"UseParallel" オプションを true に設定してreadallを使用し、変換関数の並列処理を有効にします。関数 readall は関数 read の出力を既定では最初の次元で連結するため、フレームを cell 配列で返し、手動で 4 番目の次元で連結します。

% Read the training and validation frames into the memory
pctExists = parallelComputingLicenseExists();
trainFrames = transform(trainDSTrans, @helperModClassReadFrame);
rxTrainFrames = readall(trainFrames,"UseParallel",pctExists);
rxTrainFrames = cat(4, rxTrainFrames{:});
validFrames = transform(validDSTrans, @helperModClassReadFrame);
rxValidFrames = readall(validFrames,"UseParallel",pctExists);
rxValidFrames = cat(4, rxValidFrames{:});

% Read the training and validation labels into the memory
trainLabels = transform(trainDSTrans, @helperModClassReadLabel);
rxTrainLabels = readall(trainLabels,"UseParallel",pctExists);
validLabels = transform(validDSTrans, @helperModClassReadLabel);
rxValidLabels = readall(validLabels,"UseParallel",pctExists);

CNN の学習

この例では、6 つの畳み込み層と 1 つの全結合層で構成される CNN を使用します。最後を除く各畳み込み層の後には、バッチ正規化層、正規化線形ユニット (ReLU) 活性化層、および最大プーリング層が続きます。最後の畳み込み層では、最大プーリング層が平均プーリング層に置き換えられます。出力層にはソフトマックス活性化があります。ネットワーク設計ガイドについては、深層学習のヒントとコツ (Deep Learning Toolbox)を参照してください。

modClassNet = helperModClassCNN(modulationTypes,sps,spf);

次に、ミニバッチ サイズが 1024 の SGDM ソルバーを使用するようにTrainingOptionsSGDM (Deep Learning Toolbox)を構成します。エポックの数を大きくしても学習の利点は大きくならないため、エポックの最大数は 20 に設定します。既定では、'ExecutionEnvironment' プロパティは 'auto' に設定されており、関数 trainNetwork は GPU を使用できる場合はそれを使用し、使用できない場合は CPU を使用します。GPU を使用するには、Parallel Computing Toolbox のライセンスが必要です。初期学習率を 3x10-1 に設定します。6 エポックごとに 0.75 ずつ学習率を下げます。学習の進行状況をプロットするため、'Plots' を 'training-progress' に設定します。ネットワークの学習には NVIDIA® GeForce RTX 3080 GPU で約 3 分を要します。

maxEpochs = 20;
miniBatchSize = 1024;
options = helperModClassTrainingOptions(maxEpochs,miniBatchSize,...
  numel(rxTrainLabels),rxValidFrames,rxValidLabels);

ネットワークを学習させるか、既に学習済みのネットワークを使用します。既定では、この例は学習済みネットワークを使用します。

if trainNow == true
  elapsedTime = seconds(toc);
  elapsedTime.Format = 'hh:mm:ss';
  fprintf('%s - Training the network\n', elapsedTime)
  trainedNet = trainNetwork(rxTrainFrames,rxTrainLabels,modClassNet,options);
else
  load trainedModulationClassificationNetwork
end

学習の進行状況のプロットで示されているように、このネットワークは約 20 エポックで 97% を超える精度に収束します。

テスト フレームの分類精度を取得して、学習済みネットワークを評価します。結果は、このネットワークがこの波形グループに対して約 95% の精度を達成していることを示しています。

elapsedTime = seconds(toc);
elapsedTime.Format = 'hh:mm:ss';
fprintf('%s - Classifying test frames\n', elapsedTime)
00:00:25 - Classifying test frames
% Read the test frames into the memory
testFrames = transform(testDSTrans, @helperModClassReadFrame);
rxTestFrames = readall(testFrames,"UseParallel",pctExists);
rxTestFrames = cat(4, rxTestFrames{:});

% Read the test labels into the memory
testLabels = transform(testDSTrans, @helperModClassReadLabel);
rxTestLabels = readall(testLabels,"UseParallel",pctExists);

rxTestPred = classify(trainedNet,rxTestFrames);
testAccuracy = mean(rxTestPred == rxTestLabels);
disp("Test accuracy: " + testAccuracy*100 + "%")
Test accuracy: 95.4545%

テスト フレームの混同行列をプロットします。この行列で示されているように、このネットワークは 16-QAM フレームと 64-QAM フレームを混同します。この問題は、各フレームが 128 シンボルのみを伝送し、16-QAM は 64-QAM のサブセットであるために想定されています。SSB-AM 信号には DSB-AM 信号のスペクトルのちょうど半分が含まれているため、ネットワークは DSB-AM フレームと SSB-AM フレームも混同します。

figure
cm = confusionchart(rxTestLabels, rxTestPred);
cm.Title = 'Confusion Matrix for Test Data';
cm.RowSummary = 'row-normalized';
cm.Parent.Position = [cm.Parent.Position(1:2) 950 550];

SDR によるテスト

関数 helperModClassSDRTest を使用し、無線信号によって学習済みネットワークのパフォーマンスをテストします。このテストを実行するには、送信および受信専用の SDR が必要です。2 台の ADALM-PLUTO 無線機を使用するか、送信用の 1 台の ADALM-PLUTO 無線機と受信用の 1 台の USRP® 無線機を使用できます。Install Support Package for Analog Devices ADALM-PLUTO Radioが必要です。USRP® 無線機を使用する場合、Install Communications Toolbox Support Package for USRP Radioも必要です。関数 helperModClassSDRTest は学習信号の生成で使用されたものと同じ変調関数を使用し、ADALM-PLUTO 無線機を使用して送信します。チャネルをシミュレートする代わりに、信号の受信用に構成された SDR (ADALM-PLUTO 無線機または USRP® 無線機) を使用してチャネルで劣化した信号を取得します。前に使用したものと同じ関数 classify で学習済みネットワークを使用して変調タイプを予測します。次のコード セグメントを実行すると、混同行列が生成され、テスト精度が出力されます。

radioPlatform = "ADALM-PLUTO";

switch radioPlatform
  case "ADALM-PLUTO"
    if helperIsPlutoSDRInstalled() == true
      radios = findPlutoRadio();
      if length(radios) >= 2
        helperModClassSDRTest(radios);
      else
        disp('Selected radios not found. Skipping over-the-air test.')
      end
    end
  case {"USRP B2xx","USRP X3xx","USRP N2xx"}
    if (helperIsUSRPInstalled() == true) && (helperIsPlutoSDRInstalled() == true)
      txRadio = findPlutoRadio();
      rxRadio = findsdru();
      switch radioPlatform
        case "USRP B2xx"
          idx = contains({rxRadio.Platform}, {'B200','B210'});
        case "USRP X3xx"
          idx = contains({rxRadio.Platform}, {'X300','X310'});
        case "USRP N2xx"
          idx = contains({rxRadio.Platform}, 'N200/N210/USRP2');
      end
      rxRadio = rxRadio(idx);
      if (length(txRadio) >= 1) && (length(rxRadio) >= 1)
        helperModClassSDRTest(rxRadio);
      else
        disp('Selected radios not found. Skipping over-the-air test.')
      end
    end
end

2 台の静止した ADALM-PLUTO 無線機を約 2 フィート離して使用している場合、次の混同行列でのネットワークの全体の精度は 99% を達成します。結果は実験の設定によって異なります。

その他の調査

フィルター数やフィルター サイズなどのハイパーパラメーター パラメーターの最適化や、層の追加やさまざまな活性化層の使用などのネットワーク構造の最適化により、精度を向上させることができます。

Communication Toolbox には、多くの追加の変調タイプとチャネル障害が用意されています。詳細は、変調伝播とチャネル モデルの節を参照してください。LTE ToolboxWLAN Toolbox、および 5G Toolbox を使用して、標準固有の信号を追加することもできます。Phased Array System Toolbox を使用してレーダー信号を追加することもできます。

補助ファイル

関数 helperModClassGetModulator では、変調信号の生成に使用される MATLAB® 関数が示されています。詳細については、次の関数と System object も参照してください。

ローカル関数

function pool = getPoolSafe()
if exist("gcp","file") && license('test','distrib_computing_toolbox')
  pool = gcp;
  if isempty(pool)
    pool = parpool;
  end
else
  pool = [];
end
end

参考文献

  1. O'Shea, T. J., J. Corgan, and T. C. Clancy."Convolutional Radio Modulation Recognition Networks."Preprint, submitted June 10, 2016. https://arxiv.org/abs/1602.04105

  2. O'Shea, T. J., T. Roy, and T. C. Clancy."Over-the-Air Deep Learning Based Radio Signal Classification."IEEE Journal of Selected Topics in Signal Processing.Vol. 12, Number 1, 2018, pp. 168–179.

  3. Liu, X., D. Yang, and A. E. Gamal."Deep Neural Network Architectures for Modulation Classification."Preprint, submitted January 5, 2018. https://arxiv.org/abs/1712.00443v3

関連するトピック