Main Content

t-SNE の出力関数

t-SNE の出力関数の説明

tsne"出力関数" は、t-SNE アルゴリズムの最適化反復が NumPrint 回行われるごとに実行される関数です。出力関数では、プロットを作成するか、ファイルまたはワークスペース変数にデータを記録できます。この関数でアルゴリズムの進行状況を変更することはできませんが、反復の停止は可能です。

出力関数を設定するには、関数 tsne の名前と値のペアの引数 Options を使用します。Options には、statset または struct を使用して作成した構造体を設定します。構造体 Options'OutputFcn' フィールドには、関数ハンドルまたは関数ハンドルの cell 配列を設定します。

たとえば、outfun.m という名前の出力関数を設定するには、次のコマンドを使用します。

opts = statset('OutputFcn',@outfun);
Y = tsne(X,'Options',opts);

出力関数を記述するには、次の構文を使用します。

function stop = outfun(optimValues,state)

stop = false; % do not stop by default
switch state
    case 'init'
        % Set up plots or open files
    case 'iter'
        % Draw plots or update variables
    case 'done'
        % Clean up plots or files
end

tsne は、変数 state および optimValues を関数に渡します。コード スニペットに示されているように、state の値は 'init''iter' または 'done' です。

tsne optimValues 構造体

optimValues フィールド説明
'iteration'反復回数
'fval'最初の 99 回の反復における強調で修正したカルバック・ライブラー ダイバージェンス
'grad'最初の 99 回の反復における強調で修正したカルバック・ライブラー ダイバージェンスの勾配
'Exaggeration'現在の反復で使用している強調パラメーターの値
'Y'現在の埋め込み

t-SNE のカスタム出力関数

この例では、tsne で出力関数を使用する方法を示します。

カスタム出力関数

以下のコードは、次のタスクを実行する出力関数です。

  • カルバック・ライブラー ダイバージェンスおよびその勾配のノルムの履歴をワークスペース変数に保持する。

  • 反復の進行に応じて解と履歴をプロットする。

  • 情報を失わずに反復を早く停止するための [Stop] ボタンをプロットに表示する。

この出力関数には、正しいデータの分類をプロットで表示できるようにする追加の入力変数 species があります。species など追加のパラメーターを関数に含める方法については、関数のパラメーター化を参照してください。

function stop = KLLogging(optimValues,state,species)
persistent h kllog iters stopnow
switch state
    case 'init'
        stopnow = false;
        kllog = [];
        iters = [];
        h = figure;
        c = uicontrol('Style','pushbutton','String','Stop','Position', ...
            [10 10 50 20],'Callback',@stopme);
    case 'iter'
        kllog = [kllog; optimValues.fval,log(norm(optimValues.grad))];
        assignin('base','history',kllog)
        iters = [iters; optimValues.iteration];
        if length(iters) > 1
            figure(h)
            subplot(2,1,2)
            plot(iters,kllog);
            xlabel('Iterations')
            ylabel('Loss and Gradient')
            legend('Divergence','log(norm(gradient))')
            title('Divergence and log(norm(gradient))')
            subplot(2,1,1)
            gscatter(optimValues.Y(:,1),optimValues.Y(:,2),species)
            title('Embedding')
            drawnow
        end
    case 'done'
        % Nothing here
end
stop = stopnow;

function stopme(~,~)
stopnow = true;
end
end

カスタム出力関数の使用

tsne を使用して、4 次元データ セットであるフィッシャーのアヤメのデータを 2 次元でプロットします。初期の反復では発散が強調値によってスケーリングされているので、100 回目の反復で Divergence の値が低下しています。最後の数百回の反復では埋め込みの大部分が変化していないので、反復中に [Stop] ボタンをクリックして時間を節約できます。

load fisheriris
rng default % for reproducibility
opts = statset('OutputFcn',@(optimValues,state) KLLogging(optimValues,state,species));
Y = tsne(meas,'Options',opts,'Algorithm','exact');

関連するトピック