Strassenアルゴリズムの分割統治

画像



こんにちは友人! 1つの悪名高い教育プロジェクトの生徒として、 bo_0mIはAdvanced Java Programmingコースの入門講義の後、最初の宿題を受け取りました。 行列を乗算するプログラムを実装する必要がありました。 そして大丈夫でしたが、ジョーカー会議が来週開催されることになったので、私たちの先生はこの機会にレッスンをキャンセルすることを決め、金曜日の夕方に数時間無料で過ごしました。 無駄に時間を無駄にしないでください! 誰も急いでいないので、あなたは創造的になることができます。



フードの下でようこそ↓



頭に浮かぶ最初のもの


おそらく、工科大学のすべての学生は行列を乗算する必要がありました。 アルゴリズムは常に1つ、つまり単純な3乗法でした。 そして、どのように聞こえても、この方法はそれほど悪くありません(行列の次元が100未満の場合)。



私たちは皆これから始めました。



for (int i = 0; i < A.rows(); i++) { for (int j = 0; j < B.columns(); j++) { for (int k = 0; k < A.columns(); k++) { C[i][j] += A[i][k] * B[k][j]; } } }
      
      





今後は、転置を使用した修正版を使用すると言います。 この変更については、 ここだけでなく、それだけはありません。



さて、さらに進んでみましょう!



Strassenアルゴリズム


おそらく誰もが知っているわけではありませんが、アルゴリズムの作成者であるVolker Strassenは生きているだけでなく、積極的に教えており、コンスタンツ大学数学統計学部の名誉教授でもあります。 少なくともwikiでこの人物について読んでください。

ウィキペディアの理論の一部:



AとBを2(n * n)-行列、nを2のべき乗とします。次に、各行列AとBを4((n / 2)*(n / 2))-matrixに分割し、それらを通して表現できます。行列AとBの積:


画像






新しい要素を定義します。


画像






したがって、再帰の各段階で必要な乗算は7回だけです。 行列Cの要素は、Pkから次の式で表されます。


画像






行列Ci、jのサイズが十分に小さくなるまで、再帰的なプロセスがn回続き、その後、行列乗算の通常の方法が使用されます。 これは、Strassenアルゴリズムが、追加の数が多いために小さな行列の通常のアルゴリズムと比較して効率が低下するためです。


練習に行こう!



Strassenアルゴリズムを実装するには、追加の関数が必要です。 上記のように、アルゴリズムは次元が2次に等しい正方行列でのみ機能するため、元の行列をこの形式にします。



このために、新しい次元を定義する関数が実装されました:



 private static int log2(int x) { int result = 1; while ((x >>= 1) != 0) result++; return result; } //****************************************************************************************** private static int getNewDimension(int[][] a, int[][] b) { return 1 << log2(Collections.max(Arrays.asList(a.length, a[0].length, b[0].length))); //  -  }
      
      





そして、マトリックスを目的のサイズに拡張する関数:



 private static int[][] addition2SquareMatrix(int[][] a, int n) { int[][] result = new int[n][n]; for (int i = 0; i < a.length; i++) { for (int j = 0; j < a[i].length; j++) { result[i][j] = a[i][j]; } } return result; }
      
      





これで、ソースマトリックスはStrassenアルゴリズムを実装するための要件を満たします。 また、サイズn * nのマトリックスを4つのマトリックス(n / 2)*(n / 2)に分割し、マトリックスを復元するための逆関数を必要とします。



 private static void splitMatrix(int[][] a, int[][] a11, int[][] a12, int[][] a21, int[][] a22) { int n = a.length >> 1; for (int i = 0; i < n; i++) { System.arraycopy(a[i], 0, a11[i], 0, n); System.arraycopy(a[i], n, a12[i], 0, n); System.arraycopy(a[i + n], 0, a21[i], 0, n); System.arraycopy(a[i + n], n, a22[i], 0, n); } } //****************************************************************************************** private static int[][] collectMatrix(int[][] a11, int[][] a12, int[][] a21, int[][] a22) { int n = a11.length; int[][] a = new int[n << 1][n << 1]; for (int i = 0; i < n; i++) { System.arraycopy(a11[i], 0, a[i], 0, n); System.arraycopy(a12[i], 0, a[i], n, n); System.arraycopy(a22[i], 0, a[i + n], n, n); } return a; }
      
      





最も興味深いことになりましたが、Strassenアルゴリズムによる行列乗算の主な機能は次のとおりです。



Strassenアルゴリズム
 private static int[][] multiStrassen(int[][] a, int[][] b, int n) { if (n <= 64) { return multiply(a, b); } n = n >> 1; int[][] a11 = new int[n][n]; int[][] a12 = new int[n][n]; int[][] a21 = new int[n][n]; int[][] a22 = new int[n][n]; int[][] b11 = new int[n][n]; int[][] b12 = new int[n][n]; int[][] b21 = new int[n][n]; int[][] b22 = new int[n][n]; splitMatrix(a, a11, a12, a21, a22); splitMatrix(b, b11, b12, b21, b22); int[][] p1 = multiStrassen(summation(a11, a22), summation(b11, b22), n); int[][] p2 = multiStrassen(summation(a21, a22), b11, n); int[][] p3 = multiStrassen(a11, subtraction(b12, b22), n); int[][] p4 = multiStrassen(a22, subtraction(b21, b11), n); int[][] p5 = multiStrassen(summation(a11, a12), b22, n); int[][] p6 = multiStrassen(subtraction(a21, a11), summation(b11, b12), n); int[][] p7 = multiStrassen(subtraction(a12, a22), summation(b21, b22), n); int[][] c11 = summation(summation(p1, p4), subtraction(p7, p5)); int[][] c12 = summation(p3, p5); int[][] c21 = summation(p2, p4); int[][] c22 = summation(subtraction(p1, p2), summation(p3, p6)); return collectMatrix(c11, c12, c21, c22); }
      
      







これで終わりかもしれません。 実装されたアルゴリズムは宿題が完了すると機能しますが、好奇心が強い人大人のパフォーマンスを切望します。 Java 7をご利用ください。



並列化する時です


Java 7は、再帰的なタスクを並列化するための優れたAPIを提供します。 そのリリースで、java.util.concurrentパッケージへの追加の1つが登場しました-Divide and Conquerパラダイムの実装-Fork-Join。 これは、タスクを再帰的にサブタスクに分割し、解決して、結果を結合するという考え方です。 この技術の詳細については、 ドキュメントをご覧ください。



このパラダイムをStrassenアルゴリズムにどれだけ簡単かつ効果的に適用できるかを見てみましょう。



Fork / Joinを使用したアルゴリズムの実装
 private static class myRecursiveTask extends RecursiveTask<int[][]> { private static final long serialVersionUID = -433764214304695286L; int n; int[][] a; int[][] b; public myRecursiveTask(int[][] a, int[][] b, int n) { this.a = a; this.b = b; this.n = n; } @Override protected int[][] compute() { if (n <= 64) { return multiply(a, b); } n = n >> 1; int[][] a11 = new int[n][n]; int[][] a12 = new int[n][n]; int[][] a21 = new int[n][n]; int[][] a22 = new int[n][n]; int[][] b11 = new int[n][n]; int[][] b12 = new int[n][n]; int[][] b21 = new int[n][n]; int[][] b22 = new int[n][n]; splitMatrix(a, a11, a12, a21, a22); splitMatrix(b, b11, b12, b21, b22); myRecursiveTask task_p1 = new myRecursiveTask(summation(a11,a22),summation(b11,b22),n); myRecursiveTask task_p2 = new myRecursiveTask(summation(a21,a22),b11,n); myRecursiveTask task_p3 = new myRecursiveTask(a11,subtraction(b12,b22),n); myRecursiveTask task_p4 = new myRecursiveTask(a22,subtraction(b21,b11),n); myRecursiveTask task_p5 = new myRecursiveTask(summation(a11,a12),b22,n); myRecursiveTask task_p6 = new myRecursiveTask(subtraction(a21,a11),summation(b11,b12),n); myRecursiveTask task_p7 = new myRecursiveTask(subtraction(a12,a22),summation(b21,b22),n); task_p1.fork(); task_p2.fork(); task_p3.fork(); task_p4.fork(); task_p5.fork(); task_p6.fork(); task_p7.fork(); int[][] p1 = task_p1.join(); int[][] p2 = task_p2.join(); int[][] p3 = task_p3.join(); int[][] p4 = task_p4.join(); int[][] p5 = task_p5.join(); int[][] p6 = task_p6.join(); int[][] p7 = task_p7.join(); int[][] c11 = summation(summation(p1, p4), subtraction(p7, p5)); int[][] c12 = summation(p3, p5); int[][] c21 = summation(p2, p4); int[][] c22 = summation(subtraction(p1, p2), summation(p3, p6)); return collectMatrix(c11, c12, c21, c22); } }
      
      







クライマックス


おそらく、実際のハードウェア上でアルゴリズムのパフォーマンスを比較することに熱心になっているでしょう。 すぐに正方行列のテストを実施することを規定します。 だから私たちは持っています:



  1. 従来の(キュービック)行列乗算法
  2. 転置を使用する従来の
  3. Strassenアルゴリズム
  4. 並列Strassenアルゴリズム


行列の次元は、間隔[100..4000]および100の増分で設定されます。



画像



予想どおり、最初のアルゴリズムはすぐに上位3つから外れました。 しかし、彼の近代化された兄弟(転置オプション)では、物事はそれほど単純ではありません。 かなり大きな次元であっても、このアルゴリズムは劣っているだけでなく、多くの場合、シングルスレッドStrassenアルゴリズムよりも優れています。 それでも、Fork-Join Frameworkの形式の切り札を持っているため、パフォーマンスを大幅に向上させることができました。 Strassenアルゴリズムの並列化により、乗算時間をほぼ3倍に短縮でき、最終的な合計をリードできました。



» ここにソースコードを掲載しました



私たちの仕事についてのフィードバックやコメントを歓迎します。 ご清聴ありがとうございました!



All Articles