高速化された深層学習関数の出力のチェック
この例では、高速化された関数の出力が基になる関数の出力と一致するかどうかをチェックする方法を説明します。
高速化された関数の出力が、基になる関数の出力と異なる場合があります。たとえば、乱数生成を使用する関数 (ネットワークの入力に追加するランダム ノイズを生成する関数など) を高速化する場合は、注意しなければなりません。dlarray
オブジェクトではない乱数を生成する関数のトレースをキャッシュする場合、高速化された関数は生成された乱数をトレースにキャッシュします。このトレースを再利用すると、高速化された関数はキャッシュされた乱数値を使用します。高速化された関数は新しい乱数値を生成しません。
高速化された関数の出力が基になる関数の出力と一致するかどうかをチェックするには、高速化された関数の CheckMode
プロパティを使用します。高速化された関数の CheckMode
プロパティが 'tolerance'
で、指定された許容誤差よりも出力の差が大きい場合、高速化された関数によって警告がスローされます。
関数 dlaccelerate
を使用し、例の最後にリストされている関数 myUnsupportedFun
を高速化します。関数 myUnsupportedFun
は、ランダム ノイズを生成し、それを入力に追加します。この関数は dlarray
オブジェクトではない乱数を生成するため、高速化をサポートしていません。
accfun = dlaccelerate(@myUnsupportedFun)
accfun = AcceleratedFunction with properties: Function: @myUnsupportedFun Enabled: 1 CacheSize: 50 HitRate: 0 Occupancy: 0 CheckMode: 'none' CheckTolerance: 1.0000e-04
関数 clearCache
を使用し、過去にキャッシュされたトレースをクリアします。
clearCache(accfun)
再利用されたキャッシュに格納されたトレースの出力が、基になる関数の出力と一致するかどうかをチェックするには、CheckMode
プロパティを 'tolerance'
に設定します。
accfun.CheckMode = 'tolerance'
accfun = AcceleratedFunction with properties: Function: @myUnsupportedFun Enabled: 1 CacheSize: 50 HitRate: 0 Occupancy: 0 CheckMode: 'tolerance' CheckTolerance: 1.0000e-04
1 の配列を dlarray
入力として指定し、高速化された関数を評価します。
dlX = dlarray(ones(3,3)); dlY = accfun(dlX)
dlY = 3×3 dlarray 1.8147 1.9134 1.2785 1.9058 1.6324 1.5469 1.1270 1.0975 1.9575
同じ入力を使って、高速化された関数を再度評価します。高速化された関数は、新しい乱数値を生成せずに、キャッシュされたランダム ノイズの値を再利用するため、再利用されたトレースの出力は基になる関数の出力と一致しません。高速化された関数の CheckMode
プロパティが 'tolerance'
で、出力が異なる場合、高速化された関数によって警告がスローされます。
dlY = accfun(dlX)
Warning: Accelerated outputs differ from underlying function outputs.
dlY = 3×3 dlarray 1.8147 1.9134 1.2785 1.9058 1.6324 1.5469 1.1270 1.0975 1.9575
関数 rand
の 'like'
オプションを使用して dlarray
オブジェクトで乱数を生成した場合は、高速化がサポートされます。高速化された関数で乱数を生成する場合、その関数が関数 rand
を使用しており、トレースされた dlarray
オブジェクト (入力の dlarray
オブジェクトに依存する dlarray
オブジェクト) に 'like'
オプションが設定されていることを確認してください。
例の最後にリストされている関数 mySupportedFun
を高速化します。関数 mySupportedFun
は、トレースされた dlarray
オブジェクトで 'like'
オプションを使用してノイズを生成し、そのノイズを入力に追加します。
accfun2 = dlaccelerate(@mySupportedFun);
関数 clearCache
を使用し、過去にキャッシュされたトレースをクリアします。
clearCache(accfun2)
再利用されたキャッシュに格納されたトレースの出力が、基になる関数の出力と一致するかどうかをチェックするには、CheckMode
プロパティを 'tolerance'
に設定します。
accfun2.CheckMode = 'tolerance';
前回と同じ入力を使って、高速化された関数を 2 回評価します。再利用されたキャッシュの出力が基になる関数の出力と一致するため、高速化された関数は警告をスローしません。
dlY = accfun2(dlX)
dlY = 3×3 dlarray 1.7922 1.0357 1.6787 1.9595 1.8491 1.7577 1.6557 1.9340 1.7431
dlY = accfun2(dlX)
dlY = 3×3 dlarray 1.3922 1.7060 1.0462 1.6555 1.0318 1.0971 1.1712 1.2769 1.8235
出力が一致するかどうかをチェックする場合、追加の処理が必要となり、関数の評価に必要な分だけ処理時間が増加します。出力をチェックしたら、CheckMode
プロパティを 'none'
に設定します。
accfun1.CheckMode = 'none'; accfun2.CheckMode = 'none';
関数の例
関数 myUnsupportedFun
は、ランダム ノイズを生成し、それを入力に追加します。この関数は dlarray
オブジェクトではない乱数を生成するため、高速化をサポートしていません。
function out = myUnsupportedFun(dlX) sz = size(dlX); noise = rand(sz); out = dlX + noise; end
関数 mySupportedFun
は、トレースされた dlarray
オブジェクトで 'like'
オプションを使用してノイズを生成し、そのノイズを入力に追加します。
function out = mySupportedFun(dlX) sz = size(dlX); noise = rand(sz,'like',dlX); out = dlX + noise; end
参考
dlaccelerate
| AcceleratedFunction
| clearCache
| dlarray
| dlgradient
| dlfeval