このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。
深層学習による変調の分類
この例では、畳み込みニューラル ネットワーク (CNN) を変調分類に使用する方法を説明します。チャネルで劣化した合成波形を生成します。生成された波形を学習データとして使用して、変調分類用の CNN を学習させます。その後、ソフトウェア無線 (SDR) ハードウェアと無線信号を使用して CNN をテストします。
CNN を使用した変調タイプの予測
この例の学習済み CNN は、以下の 8 つのデジタル変調タイプと 3 つのアナログ変調タイプを認識します。
2 位相偏移変調 (BPSK)
直交位相偏移変調 (QPSK)
8-ary 位相偏移変調 (8-PSK)
16-ary 直交振幅変調 (16-QAM)
64-ary 直交振幅変調 (64-QAM)
4-ary パルス振幅変調 (PAM4)
ガウス周波数偏移変調 (GFSK)
連続位相周波数偏移変調 (CPFSK)
ブロードキャスト FM (B-FM)
両側波帯振幅変調 (DSB-AM)
単側波帯振幅変調 (SSB-AM)
modulationTypes = categorical(sort(["BPSK", "QPSK", "8PSK", ... "16QAM", "64QAM", "PAM4", "GFSK", "CPFSK", ... "B-FM", "DSB-AM", "SSB-AM"]));
まず、学習済みネットワークを読み込みます。ネットワーク学習の詳細は、CNN の学習の節を参照してください。
load trainedModulationClassificationNetwork
trainedNet
trainedNet = dlnetwork with properties: Layers: [19×1 nnet.cnn.layer.Layer] Connections: [18×2 table] Learnables: [22×3 table] State: [10×3 table] InputNames: {'Input Layer'} OutputNames: {'SoftMax'} Initialized: 1 View summary with summary.
この学習済み CNN は、チャネルで劣化した 1024 個のサンプルを取得し、各フレームの変調タイプを予測します。ライス マルチパス フェージング、中心周波数とサンプリング時間のドリフト、および AWGN で劣化したいくつかの PAM4 フレームを生成します。以下の関数を使用して、CNN をテストするための合成信号を生成します。その後、CNN を使用してフレームの変調タイプを予測します。
randi
:ランダムなビットの生成pammod
: ビットの PAM4 変調rcosdesign
:ルート レイズド コサイン パルス整形フィルターの設計filter
:シンボルのパルス整形comm.RicianChannel
:ライス マルチパス チャネルの適用comm.PhaseFrequencyOffset
:クロック オフセットによる位相シフト/周波数シフトの適用interp1
:クロック オフセットによるタイミングのずれの適用awgn
:AWGN の追加
% Set the random number generator to a known state to be able to regenerate % the same frames every time the simulation is run rng(123456) % Random bits d = randi([0 3], 1024, 1); % PAM4 modulation syms = pammod(d,4); % Square-root raised cosine filter filterCoeffs = rcosdesign(0.35,4,8); tx = filter(filterCoeffs,1,upsample(syms,8)); % Channel SNR = 30; maxOffset = 5; fc = 902e6; fs = 200e3; multipathChannel = comm.RicianChannel(... 'SampleRate', fs, ... 'PathDelays', [0 1.8 3.4] / 200e3, ... 'AveragePathGains', [0 -2 -10], ... 'KFactor', 4, ... 'MaximumDopplerShift', 4); frequencyShifter = comm.PhaseFrequencyOffset(... 'SampleRate', fs); % Apply an independent multipath channel reset(multipathChannel) outMultipathChan = multipathChannel(tx); % Determine clock offset factor clockOffset = (rand() * 2*maxOffset) - maxOffset; C = 1 + clockOffset / 1e6; % Add frequency offset frequencyShifter.FrequencyOffset = -(C-1)*fc; outFreqShifter = frequencyShifter(outMultipathChan); % Add sampling time drift t = (0:length(tx)-1)' / fs; newFs = fs * C; tp = (0:length(tx)-1)' / newFs; outTimeDrift = interp1(t, outFreqShifter, tp); % Add noise rx = awgn(outTimeDrift,SNR,0); % Frame generation for classification unknownFrames = helperModClassGetNNFrames(rx); % Classification scores1 = predict(trainedNet,unknownFrames); prediction1 = scores2label(scores1,modulationTypes);
硬判定に似た分類器の予測を返します。このネットワークはフレームを PAM4 フレームとして正しく識別します。変調信号の生成の詳細については、関数 helperModClassGetModulator を参照してください。
prediction1
prediction1 = 7×1 categorical
PAM4
PAM4
PAM4
PAM4
PAM4
PAM4
PAM4
分類器は各フレームのスコアのベクトルも返します。スコアは、各フレームが予測された変調タイプである確率に相当します。スコアをプロットします。
helperModClassPlotScores(scores1,modulationTypes)
変調分類やその他のタスクに CNN を使用する前に、まず、既知の (ラベル付きの) データを使用してネットワークを学習させる必要があります。この例の最初の部分は、変調器、フィルター、およびチャネル障害などの Communications Toolbox™ の機能を使用して、合成学習データを生成する方法を示しています。2 番目の部分では、変調分類のタスクのために CNN を定義し、学習させて、テストすることに焦点を当てています。3 番目の部分では、ソフトウェア無線 (SDR) プラットフォームを使用して、無線信号でネットワーク性能をテストします。
学習のための波形生成
各変調タイプに対して 10,000 フレームを生成します。このうち 80% は学習に使用され、10% は検証に使用され、10% はテストに使用されます。ネットワーク学習フェーズでは、学習フレームと検証フレームを使用します。最終的な分類精度は、テスト フレームを使用して得ます。各フレームの長さは 1024 サンプルで、サンプルレートは 200 kHz です。デジタル変調タイプでは、8 つのサンプルがシンボルを表します。ネットワークでは、(ビデオのように) 複数の連続したフレームではなく、単一のフレームに基づいて各判定が行われます。デジタル変調タイプとアナログ変調タイプの中心周波数を、それぞれ 902 MHz と 100 MHz と仮定します。
この例を迅速に実行するには、学習済みネットワークを使用して少数の学習フレームを生成します。コンピューターでネットワークに学習させるには、"Train network now" オプションを選択します (つまり、trainNow を true に設定します)。
trainNow = false; if trainNow == true numFramesPerModType = 10000; else numFramesPerModType = 200; end percentTrainingSamples = 80; percentValidationSamples = 10; percentTestSamples = 10; sps = 8; % Samples per symbol spf = 1024; % Samples per frame fs = 200e3; % Sample rate fc = [902e6 100e6]; % Center frequencies
チャネル障害の作成
以下の障害があるチャネルを通して各フレームを渡します。
AWGN
ライス マルチパス フェージング
中心周波数オフセットとサンプリング時間のドリフトをもたらすクロック オフセット
この例のネットワークは単一のフレームに基づいて判定を行うため、各フレームは独立したチャネルを通過しなければなりません。
AWGN
チャネルは 30 dB の SNR で AWGN を追加します。関数 awgn
を使用してチャネルを実装します。
ライス マルチパス
チャネルは、comm.RicianChannel
System object™ を使用して、ライス マルチパス フェージング チャネル経由で信号を渡します。対応する平均パス ゲインが [0 -2 -10] dB の [0 1.8 3.4] サンプルの遅延プロファイルを仮定します。K ファクターは 4、最大ドップラー シフトは 4 Hz です。これは、902 MHz での歩行速度に相当します。以下の設定を使用してチャネルを実装します。
クロック オフセット
クロック オフセットは、送信機と受信機の内部クロック ソースが不正確なことが原因で発生します。クロック オフセットは、信号をベースバンドにダウンコンバートするために使用される中心周波数と、デジタルからアナログへの変換器のサンプル レートが理想値とは異なる値になる原因となります。このチャネル シミュレーターは、 として表されるクロック オフセット係数 を使用します。ここで、 はクロック オフセットです。各フレームで、チャネルは範囲 [ ] 内の一様分布の値のセットから、乱数の 値を生成します。ここで、 は最大クロック オフセットです。クロック オフセットは 100 万分の 1 (ppm) 単位で測定されます。この例では、最大クロック オフセットを 5 ppm と仮定します。
maxDeltaOff = 5; deltaOff = (rand()*2*maxDeltaOff) - maxDeltaOff; C = 1 + (deltaOff/1e6);
周波数オフセット
クロック オフセット係数 と中心周波数に基づいて、各フレームに周波数オフセットを適用します。comm.PhaseFrequencyOffset
を使用してチャネルを実装します。
サンプル レート オフセット
クロック オフセット係数 に基づいて、各フレームにサンプル レート オフセットを適用します。関数interp1
を使用してチャネルを実装し、新しいレートの でフレームをリサンプリングします。
結合されたチャネル
helperModClassTestChannel オブジェクトを使用して、3 つすべてのチャネル劣化要因をフレームに適用します。
channel = helperModClassTestChannel(... 'SampleRate', fs, ... 'SNR', SNR, ... 'PathDelays', [0 1.8 3.4] / fs, ... 'AveragePathGains', [0 -2 -10], ... 'KFactor', 4, ... 'MaximumDopplerShift', 4, ... 'MaximumClockOffset', 5, ... 'CenterFrequency', 902e6)
channel = helperModClassTestChannel with properties: SNR: 30 CenterFrequency: 902000000 SampleRate: 200000 PathDelays: [0 9.0000e-06 1.7000e-05] AveragePathGains: [0 -2 -10] KFactor: 4 MaximumDopplerShift: 4 MaximumClockOffset: 5
オブジェクト関数 info を使用して、チャネルに関する基本情報を表示できます。
chInfo = info(channel)
chInfo = struct with fields:
ChannelDelay: 6
MaximumFrequencyOffset: 4510
MaximumSampleRateOffset: 1
波形生成
各変調タイプのチャネルで劣化したフレームを生成し、そのフレームに対応するラベルを付けて MAT ファイルに格納するループを作成します。データをファイルに保存しておけば、この例を実行するたびにデータを生成する必要がなくなります。データをより効率的に共有することもできます。
各フレームの先頭からランダムな数のサンプルを削除して、過渡状態を削除し、フレームがシンボルの境界に対してランダムな開始点をもつようにします。
% Set the random number generator to a known state to be able to regenerate % the same frames every time the simulation is run rng(12) tic numModulationTypes = length(modulationTypes); channelInfo = info(channel); transDelay = 50; pool = getPoolSafe(); if ~isa(pool,"parallel.ClusterPool") dataDirectory = fullfile(tempdir,"ModClassDataFiles"); else dataDirectory = uigetdir("","Select network location to save data files"); end disp("Data file directory is " + dataDirectory)
Data file directory is C:\TEMP\ModClassDataFiles
fileNameRoot = "frame"; % Check if data files exist dataFilesExist = false; if exist(dataDirectory,'dir') files = dir(fullfile(dataDirectory,sprintf("%s*",fileNameRoot))); if length(files) == numModulationTypes*numFramesPerModType dataFilesExist = true; end end if ~dataFilesExist disp("Generating data and saving in data files...") [success,msg,msgID] = mkdir(dataDirectory); if ~success error(msgID,msg) end for modType = 1:numModulationTypes elapsedTime = seconds(toc); elapsedTime.Format = 'hh:mm:ss'; fprintf('%s - Generating %s frames\n', ... elapsedTime, modulationTypes(modType)) label = modulationTypes(modType); numSymbols = (numFramesPerModType / sps); dataSrc = helperModClassGetSource(modulationTypes(modType), sps, 2*spf, fs); modulator = helperModClassGetModulator(modulationTypes(modType), sps, fs); if contains(char(modulationTypes(modType)), {'B-FM','DSB-AM','SSB-AM'}) % Analog modulation types use a center frequency of 100 MHz channel.CenterFrequency = 100e6; else % Digital modulation types use a center frequency of 902 MHz channel.CenterFrequency = 902e6; end for p=1:numFramesPerModType % Generate random data x = dataSrc(); % Modulate y = modulator(x); % Pass through independent channels rxSamples = channel(y); % Remove transients from the beginning, trim to size, and normalize frame = helperModClassFrameGenerator(rxSamples, spf, spf, transDelay, sps); % Save data file fileName = fullfile(dataDirectory,... sprintf("%s%s%03d",fileNameRoot,modulationTypes(modType),p)); save(fileName,"frame","label") end end else disp("Data files exist. Skip data generation.") end
Generating data and saving in data files...
00:00:09 - Generating 16QAM frames 00:00:11 - Generating 64QAM frames 00:00:13 - Generating 8PSK frames 00:00:15 - Generating B-FM frames 00:00:17 - Generating BPSK frames 00:00:20 - Generating CPFSK frames 00:00:22 - Generating DSB-AM frames 00:00:24 - Generating GFSK frames 00:00:26 - Generating PAM4 frames 00:00:28 - Generating QPSK frames 00:00:30 - Generating SSB-AM frames
% Plot the amplitude of the real and imaginary parts of the example frames % against the sample number helperModClassPlotTimeDomain(dataDirectory,modulationTypes,fs)
% Plot the spectrogram of the example frames
helperModClassPlotSpectrogram(dataDirectory,modulationTypes,fs,sps)
データストアの作成
生成された複素数波形を含むファイルを signalDatastore
オブジェクトを使用して管理します。データストアは、個々のファイルはメモリに収まっても全体が収まるとは限らない場合に特に便利です。
frameDS = signalDatastore(dataDirectory,'SignalVariableNames',["frame","label"]);
学習、検証、およびテストへの分割
次に、フレームを学習データ、検証データ、およびテスト データに分割します。詳細については、helperModClassSplitData を参照してください。
splitPercentages = [percentTrainingSamples,percentValidationSamples,percentTestSamples]; [trainDS,validDS,testDS] = helperModClassSplitData(frameDS,splitPercentages);
メモリへのデータのインポート
ニューラル ネットワークの学習は反復的です。データストアでは、反復のたびにファイルからデータを読み取り、そのデータを変換してからネットワークの係数を更新します。データがコンピューターのメモリに収まる場合、データをファイルからメモリにインポートしておけば、繰り返し実行されるこのファイルからの読み取りと変換の処理が不要になり、学習が高速になります。この場合、データのファイルからの読み取りと変換が 1 回で済みます。
ファイル内のすべてのデータをメモリにインポートします。ファイルには frame
と label
の 2 つの変数があり、データストアに対するそれぞれの read
の呼び出しは cell 配列を返します。ここで、最初の要素が frame
で、2 番目の要素が label
です。transform
の関数 helperModClassReadFrame と helperModClassReadLabel を使用してフレームとラベルを読み取ります。Parallel Computing Toolbox™ ライセンスがある場合は、"UseParallel"
オプションを true
に設定してreadall
を使用し、変換関数の並列処理を有効にします。関数 readall
は関数 read
の出力を既定では最初の次元で連結するため、フレームを cell 配列で返し、手動で 4 番目の次元で連結します。
% Read the training and validation frames into the memory pctExists = parallelComputingLicenseExists(); trainFrames = transform(trainDS, @helperModClassReadFrame); rxTrainFrames = readall(trainFrames,"UseParallel",pctExists); validFrames = transform(validDS, @helperModClassReadFrame); rxValidFrames = readall(validFrames,"UseParallel",pctExists); % Read the training and validation labels into the memory trainLabels = transform(trainDS, @helperModClassReadLabel); rxTrainLabels = readall(trainLabels,"UseParallel",pctExists); validLabels = transform(validDS, @helperModClassReadLabel); rxValidLabels = readall(validLabels,"UseParallel",pctExists);
CNN の学習
この例では、5 つの畳み込み層と 1 つの全結合層で構成される CNN を使用します。最後を除く各畳み込み層の後には、バッチ正規化層、正規化線形ユニット (ReLU) 活性化層、および最大プーリング層が続きます。最後の畳み込み層では、最大プーリング層がグローバル平均プーリング層に置き換えられます。出力層にはソフトマックス活性化があります。ネットワーク設計ガイドについては、深層学習のヒントとコツ (Deep Learning Toolbox)を参照してください。
modClassNet = helperModClassCNN(modulationTypes,sps,spf);
次に、ミニバッチ サイズが 1024 の SGDM ソルバーを使用するようにTrainingOptionsSGDM
(Deep Learning Toolbox)を構成します。エポックの数を大きくしても学習の利点は大きくならないため、エポックの最大数は 20 に設定します。既定では、'ExecutionEnvironment'
プロパティは 'auto'
に設定されており、関数 trainNetwork
は GPU を使用できる場合はそれを使用し、使用できない場合は CPU を使用します。GPU を使用するには、Parallel Computing Toolbox のライセンスが必要です。初期学習率を に設定します。6 エポックごとに 0.75 ずつ学習率を下げます。学習の進行状況をプロットするため、'Plots'
を 'training-progress'
に設定します。ネットワークの学習には NVIDIA® GeForce RTX 3080 GPU で約 3 分を要します。
maxEpochs = 20; miniBatchSize = 1024; trainingPlots = "none"; metrics = []; verbose = true; validationFrequency = floor(numel(rxTrainLabels)/miniBatchSize); options = trainingOptions('sgdm', ... InitialLearnRate = 3e-1, ... MaxEpochs = maxEpochs, ... MiniBatchSize = miniBatchSize, ... Shuffle = 'every-epoch', ... Plots = trainingPlots, ... Verbose = verbose, ... ValidationData = {rxValidFrames,rxValidLabels}, ... ValidationFrequency = validationFrequency, ... ValidationPatience = 5, ... Metrics = metrics, ... LearnRateSchedule = 'piecewise', ... LearnRateDropPeriod = 6, ... LearnRateDropFactor = 0.75, ... OutputNetwork='best-validation-loss');
ネットワークを学習させるか、既に学習済みのネットワークを使用します。既定では、この例は学習済みネットワークを使用します。
if trainNow == true elapsedTime = seconds(toc); elapsedTime.Format = 'hh:mm:ss'; fprintf('%s - Training the network\n', elapsedTime) trainedNet = trainnet(rxTrainFrames,rxTrainLabels,modClassNet,"crossentropy",options); else load trainedModulationClassificationNetwork end
次のプロットは、trainingPlots
を [Training progress] に設定し、metric
を [Accuracy] に設定し、verbose
を false
に設定して実行した例を示しています。このネットワークは約 20 エポックで約 97% の精度に収束します。
テスト フレームの分類精度を取得して、学習済みネットワークを評価します。結果は、このネットワークがこの波形グループに対して約 96% の精度を達成していることを示しています。
elapsedTime = seconds(toc); elapsedTime.Format = 'hh:mm:ss'; fprintf('%s - Classifying test frames\n', elapsedTime)
00:05:21 - Classifying test frames
% Read the test frames into the memory testFrames = transform(testDS, @helperModClassReadFrame); rxTestFrames = readall(testFrames,"UseParallel",pctExists); % Read the test labels into the memory testLabels = transform(testDS, @helperModClassReadLabel); rxTestLabels = readall(testLabels,"UseParallel",pctExists); scores = predict(trainedNet,cat(3,rxTestFrames{:})); rxTestPred = scores2label(scores,modulationTypes); testAccuracy = mean(rxTestPred == rxTestLabels); disp("Test accuracy: " + testAccuracy*100 + "%")
Test accuracy: 96.4636%
テスト フレームの混同行列をプロットします。この行列で示されているように、このネットワークは 16-QAM フレームと 64-QAM フレームを混同します。この問題は、各フレームが 128 シンボルのみを伝送し、16-QAM は 64-QAM のサブセットであるために想定されています。SSB-AM 信号には DSB-AM 信号のスペクトルのちょうど半分が含まれているため、ネットワークは DSB-AM フレームと SSB-AM フレームも混同します。
figure cm = confusionchart(rxTestLabels, rxTestPred); cm.Title = 'Confusion Matrix for Test Data'; cm.RowSummary = 'row-normalized'; cm.Parent.Position = [cm.Parent.Position(1:2) 950 550];
SDR によるテスト
関数 helperModClassSDRTest を使用し、無線信号によって学習済みネットワークのパフォーマンスをテストします。このテストを実行するには、送信および受信専用の SDR が必要です。2 台の ADALM-PLUTO 無線機を使用するか、送信用の 1 台の ADALM-PLUTO 無線機と受信用の 1 台の USRP® 無線機を使用できます。Install Support Package for Analog Devices ADALM-PLUTO Radioが必要です。USRP® 無線機を使用する場合、Install Communications Toolbox Support Package for USRP Radioも必要です。関数 helperModClassSDRTest
は学習信号の生成で使用されたものと同じ変調関数を使用し、ADALM-PLUTO 無線機を使用して送信します。チャネルをシミュレートする代わりに、信号の受信用に構成された SDR (ADALM-PLUTO 無線機または USRP® 無線機) を使用してチャネルで劣化した信号を取得します。前に使用したものと同じ関数 perdict
で学習済みネットワークを使用して変調タイプを予測します。次のコード セグメントを実行すると、混同行列が生成され、テスト精度が出力されます。
radioPlatform = "ADALM-PLUTO"; switch radioPlatform case "ADALM-PLUTO" if helperIsPlutoSDRInstalled() == true radios = findPlutoRadio(); if length(radios) >= 2 helperModClassSDRTest(radios); else disp('Selected radios not found. Skipping over-the-air test.') end end case {"USRP B2xx","USRP X3xx","USRP N2xx"} if (helperIsUSRPInstalled() == true) && (helperIsPlutoSDRInstalled() == true) txRadio = findPlutoRadio(); rxRadio = findsdru(); switch radioPlatform case "USRP B2xx" idx = contains({rxRadio.Platform}, {'B200','B210'}); case "USRP X3xx" idx = contains({rxRadio.Platform}, {'X300','X310'}); case "USRP N2xx" idx = contains({rxRadio.Platform}, 'N200/N210/USRP2'); end rxRadio = rxRadio(idx); if (length(txRadio) >= 1) && (length(rxRadio) >= 1) helperModClassSDRTest(rxRadio); else disp('Selected radios not found. Skipping over-the-air test.') end end end
2 台の静止した ADALM-PLUTO 無線機を約 2 フィート離して使用している場合、次の混同行列でのネットワークの全体の精度は 99% を達成します。結果は実験の設定によって異なります。
その他の調査
フィルター数やフィルター サイズなどのハイパーパラメーター パラメーターの最適化や、層の追加やさまざまな活性化層の使用などのネットワーク構造の最適化により、精度を向上させることができます。
Communication Toolbox には、多くの追加の変調タイプとチャネル障害が用意されています。詳細は、変調と伝播とチャネル モデルの節を参照してください。LTE Toolbox、WLAN Toolbox、および 5G Toolbox を使用して、標準固有の信号を追加することもできます。Phased Array System Toolbox を使用してレーダー信号を追加することもできます。
補助ファイル
関数 helperModClassGetModulator では、変調信号の生成に使用される MATLAB® 関数が示されています。詳細については、次の関数と System object も参照してください。
ローカル関数
function pool = getPoolSafe() if exist("gcp","file") && license('test','distrib_computing_toolbox') pool = gcp; if isempty(pool) pool = parpool; end else pool = []; end end
参考文献
O'Shea, T. J., J. Corgan, and T. C. Clancy."Convolutional Radio Modulation Recognition Networks."Preprint, submitted June 10, 2016. https://arxiv.org/abs/1602.04105
O'Shea, T. J., T. Roy, and T. C. Clancy."Over-the-Air Deep Learning Based Radio Signal Classification."IEEE Journal of Selected Topics in Signal Processing.Vol. 12, Number 1, 2018, pp. 168–179.
Liu, X., D. Yang, and A. E. Gamal."Deep Neural Network Architectures for Modulation Classification."Preprint, submitted January 5, 2018. https://arxiv.org/abs/1712.00443v3
関連するトピック
- MATLAB による深層学習 (Deep Learning Toolbox)