gemm_emul_cache< do_trans_A, do_trans_B, use_alpha, use_beta > Class Template Reference
[Gemm]

Partial emulation of ATLAS/BLAS gemm(), using caching for speedup. Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes). More...

#include <gemm.hpp>

List of all members.

Static Public Member Functions

template<typename eT >
static arma_hot void apply (Mat< eT > &C, const Mat< eT > &A, const Mat< eT > &B, const eT alpha=eT(1), const eT beta=eT(0))

Detailed Description

template<const bool do_trans_A = false, const bool do_trans_B = false, const bool use_alpha = false, const bool use_beta = false>
class gemm_emul_cache< do_trans_A, do_trans_B, use_alpha, use_beta >

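Partial emulation of the ATLAS/BLAS gemm() routine (general matrix-matrix multiplication). For speed, when a row of a matrix is needed it is first copied into a contiguous temporary buffer (the "cache"), so that the inner product loop reads both operands sequentially. Matrix 'C' is assumed to have already been set to the correct size for the result (i.e. taking into account transposes); apply() does not resize it.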

Definition at line 27 of file gemm.hpp.


Member Function Documentation

template<const bool do_trans_A = false, const bool do_trans_B = false, const bool use_alpha = false, const bool use_beta = false>
template<typename eT >
static arma_hot void gemm_emul_cache< do_trans_A, do_trans_B, use_alpha, use_beta >::apply ( Mat< eT > &  C,
const Mat< eT > &  A,
const Mat< eT > &  B,
const eT  alpha = eT(1),
const eT  beta = eT(0) 
) [inline, static]

Definition at line 37 of file gemm.hpp.

References Mat< eT >::at(), Mat< eT >::colptr(), Mat< eT >::n_cols, Mat< eT >::n_rows, and trans().

00044     {
00045     arma_extra_debug_sigprint();
00046     // Computes C = [alpha*] op(A)*op(B) [+ beta*C]; op() is an optional transpose chosen by do_trans_A/do_trans_B, the bracketed terms by use_alpha/use_beta.
00047     const u32 A_n_rows = A.n_rows;
00048     const u32 A_n_cols = A.n_cols;
00049     
00050     const u32 B_n_rows = B.n_rows;
00051     const u32 B_n_cols = B.n_cols;
00052     // Case 1: neither operand is transposed -- C(row_A,col_B) = (row row_A of A) . (column col_B of B).
00053     if( (do_trans_A == false) && (do_trans_B == false) )
00054       {
00055       arma_aligned podarray<eT> tmp(A_n_cols);  // scratch buffer that holds one row of A (A_n_cols elements)
00056       eT* A_rowdata = tmp.memptr();
00057       
00058       for(u32 row_A=0; row_A < A_n_rows; ++row_A)
00059         {
00060         // Copy row row_A of A into the contiguous buffer so the inner loop can read it sequentially.
00061         for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00062           {
00063           A_rowdata[col_A] = A.at(row_A,col_A);
00064           }
00065         // Combine the buffered row with every column of B.
00066         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00067           {
00068           const eT* B_coldata = B.colptr(col_B);
00069           // Dot product over B_n_rows elements; assumes A_n_cols == B_n_rows (not checked here).
00070           eT acc = eT(0);
00071           for(u32 i=0; i < B_n_rows; ++i)
00072             {
00073             acc += A_rowdata[i] * B_coldata[i];
00074             }
00075         // Store the result; the form used depends on the use_alpha / use_beta template flags.
00076           if( (use_alpha == false) && (use_beta == false) )
00077             {
00078             C.at(row_A,col_B) = acc;
00079             }
00080           else
00081           if( (use_alpha == true) && (use_beta == false) )
00082             {
00083             C.at(row_A,col_B) = alpha * acc;
00084             }
00085           else
00086           if( (use_alpha == false) && (use_beta == true) )
00087             {
00088             C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00089             }
00090           else
00091           if( (use_alpha == true) && (use_beta == true) )
00092             {
00093             C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00094             }
00095           
00096           }
00097         }
00098       }
00099     else  // Case 2: only A is transposed -- columns of A are read directly via colptr(), so no row buffer is used.
00100     if( (do_trans_A == true) && (do_trans_B == false) )
00101       {
00102       for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00103         {
00104         // col_A is interpreted as row_A when storing the results in matrix C
00105         
00106         const eT* A_coldata = A.colptr(col_A);
00107         
00108         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00109           {
00110           const eT* B_coldata = B.colptr(col_B);
00111           // Dot product of two columns over B_n_rows elements; assumes A_n_rows == B_n_rows (not checked here).
00112           eT acc = eT(0);
00113           for(u32 i=0; i < B_n_rows; ++i)
00114             {
00115             acc += A_coldata[i] * B_coldata[i];
00116             }
00117         // Store the result; the form used depends on the use_alpha / use_beta template flags.
00118           if( (use_alpha == false) && (use_beta == false) )
00119             {
00120             C.at(col_A,col_B) = acc;
00121             }
00122           else
00123           if( (use_alpha == true) && (use_beta == false) )
00124             {
00125             C.at(col_A,col_B) = alpha * acc;
00126             }
00127           else
00128           if( (use_alpha == false) && (use_beta == true) )
00129             {
00130             C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00131             }
00132           else
00133           if( (use_alpha == true) && (use_beta == true) )
00134             {
00135             C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00136             }
00137           
00138           }
00139         }
00140       }
00141     else  // Case 3: only B is transposed -- build trans(B) in a temporary matrix and reuse the non-transposed path.
00142     if( (do_trans_A == false) && (do_trans_B == true) )
00143       {
00144       Mat<eT> B_tmp = trans(B);
00145       gemm_emul_cache<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
00146       }
00147     else  // Case 4: both operands are transposed.
00148     if( (do_trans_A == true) && (do_trans_B == true) )
00149       {
00150       // mat B_tmp = trans(B);
00151       // dgemm_arma<true, false,  use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
00152       // NOTE(review): the two lines above look like a disabled earlier variant using an explicit transpose of B; `mat` and `dgemm_arma` are not used elsewhere in this listing -- confirm before relying on them.
00153       
00154       // By using the trans(A)*trans(B) = trans(B*A) equivalency,
00155       // transpose operations are not needed
00156       // Rows of B are buffered, and each result is written to C.at(col_A,row_B), i.e. at the transposed position.
00157       arma_aligned podarray<eT> tmp(B.n_cols);  // scratch buffer that holds one row of B (B.n_cols elements)
00158       eT* B_rowdata = tmp.memptr();
00159       
00160       for(u32 row_B=0; row_B < B_n_rows; ++row_B)
00161         {
00162         // Copy row row_B of B into the contiguous buffer so the inner loop can read it sequentially.
00163         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00164           {
00165           B_rowdata[col_B] = B.at(row_B,col_B);
00166           }
00167         // Combine the buffered row with every column of A.
00168         for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00169           {
00170           const eT* A_coldata = A.colptr(col_A);
00171           // Dot product over A_n_rows elements; assumes B_n_cols == A_n_rows (not checked here).
00172           eT acc = eT(0);
00173           for(u32 i=0; i < A_n_rows; ++i)
00174             {
00175             acc += B_rowdata[i] * A_coldata[i];
00176             }
00177         // Store the result; the form used depends on the use_alpha / use_beta template flags.
00178           if( (use_alpha == false) && (use_beta == false) )
00179             {
00180             C.at(col_A,row_B) = acc;
00181             }
00182           else
00183           if( (use_alpha == true) && (use_beta == false) )
00184             {
00185             C.at(col_A,row_B) = alpha * acc;
00186             }
00187           else
00188           if( (use_alpha == false) && (use_beta == true) )
00189             {
00190             C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00191             }
00192           else
00193           if( (use_alpha == true) && (use_beta == true) )
00194             {
00195             C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00196             }
00197           
00198           }
00199         }
00200       
00201       }
00202     }