Commit 50bd7b18 authored by Edgar Solomonik
Browse files

Got CSR format working

......@@ -43,7 +43,7 @@ double mp3(Tensor<> & Ea,
return MP3_energy;
}
int sparse_mp3(int nv, int no, World & dw){
int sparse_mp3(int nv, int no, World & dw, double sp=.8){
int vvvv[] = {nv,nv,nv,nv};
int vovo[] = {nv,no,nv,no};
int vvoo[] = {nv,nv,no,no};
......@@ -76,7 +76,7 @@ int sparse_mp3(int nv, int no, World & dw){
Vijkl.fill_random(-1.0,1.0);
Vaibj.fill_random(-1.0,1.0);
Transform<> fltr([](double & d){ if (fabs(d)<.8) d=0.0; });
Transform<> fltr([=](double & d){ if (fabs(d)<sp) d=0.0; });
fltr(Vabij["abij"]);
fltr(Vijab["ijab"]);
fltr(Vabcd["abcd"]);
......@@ -88,7 +88,10 @@ int sparse_mp3(int nv, int no, World & dw){
#ifndef TEST_SUITE
double time = MPI_Wtime();
#endif
Timer_epoch dmp3("dense MP3");
dmp3.begin();
dense_energy = mp3(Ea, Ei, Fab, Fij, Vabij, Vijab, Vabcd, Vijkl, Vaibj);
dmp3.end();
#ifndef TEST_SUITE
if (dw.rank == 0)
printf("Calcluated MP3 energy %lf with dense integral tensors in time %lf sec.\n",dense_energy,MPI_Wtime()-time);
......@@ -103,7 +106,10 @@ int sparse_mp3(int nv, int no, World & dw){
#ifndef TEST_SUITE
time = MPI_Wtime();
#endif
Timer_epoch smp3("sparse MP3");
smp3.begin();
sparse_energy = mp3(Ea, Ei, Fab, Fij, Vabij, Vijab, Vabcd, Vijkl, Vaibj);
smp3.end();
#ifndef TEST_SUITE
if (dw.rank == 0)
printf("Calcluated MP3 energy %lf with sparse integral tensors in time %lf sec.\n",sparse_energy,MPI_Wtime()-time);
......@@ -135,6 +141,7 @@ char* getCmdOption(char ** begin,
int main(int argc, char ** argv){
int rank, np, nv, no, pass;
double sp;
int const in_num = argc;
char ** input_str = argv;
......@@ -152,13 +159,18 @@ int main(int argc, char ** argv){
if (no < 0) no = 7;
} else no = 7;
if (getCmdOption(input_str, input_str+in_num, "-sp")){
sp = atof(getCmdOption(input_str, input_str+in_num, "-sp"));
if (sp < 0.0 || sp > 1.0) sp = .8;
} else sp = .8;
if (rank == 0){
printf("Running sparse third-order Moller-Plesset petrubation theory (MP3) method on %d virtual and %d occupied orbitals\n",nv,no);
printf("Running sparse (%lf zeros) third-order Moller-Plesset petrubation theory (MP3) method on %d virtual and %d occupied orbitals\n",sp,nv,no);
}
{
World dw;
pass = sparse_mp3(nv, no, dw);
pass = sparse_mp3(nv, no, dw, sp);
assert(pass);
}
......
......@@ -10,7 +10,8 @@ using namespace CTF;
int spmm(int n,
int k,
World & dw){
World & dw,
double sp=.50){
Matrix<> spA(n, n, SP, dw);
Matrix<> dnA(n, n, dw);
......@@ -25,7 +26,7 @@ int spmm(int n,
dnA.fill_random(0.0,1.0);
spA["ij"] += dnA["ij"];
spA.sparsify(.5);
spA.sparsify(sp);
dnA["ij"] = 0.0;
dnA["ij"] += spA["ij"];
......@@ -78,6 +79,7 @@ char* getCmdOption(char ** begin,
int main(int argc, char ** argv){
int rank, np, n, k, pass;
double sp;
int const in_num = argc;
char ** input_str = argv;
......@@ -94,14 +96,19 @@ int main(int argc, char ** argv){
k = atoi(getCmdOption(input_str, input_str+in_num, "-k"));
if (k < 0) k = 7;
} else k = 7;
if (getCmdOption(input_str, input_str+in_num, "-sp")){
sp = atof(getCmdOption(input_str, input_str+in_num, "-sp"));
if (sp < 0.0 || sp > 1.0) sp = .8;
} else sp = .8;
{
World dw(argc, argv);
if (rank == 0){
printf("Multiplying %d-by-%d sparse matrix by %d-by-%d dense matrix\n",n,n,n,k);
printf("Multiplying %d-by-%d sparse (%lf zeros) matrix by %d-by-%d dense matrix\n",n,n,sp,n,k);
}
pass = spmm(n, k, dw);
pass = spmm(n, k, dw, sp);
assert(pass);
}
......
......@@ -16,6 +16,7 @@
#include "../redistribution/nosym_transp.h"
#include "../redistribution/redist.h"
#include "../sparse_formats/coo.h"
#include "../sparse_formats/csr.h"
#include <cfloat>
#include <limits>
......@@ -699,7 +700,10 @@ namespace CTF_int {
int64_t new_sz_A = 0;
A->rec_tsr->nnz_blk = (int64_t*)alloc(nvirt_A*sizeof(int64_t));
for (i=0; i<nvirt_A; i++){
A->rec_tsr->nnz_blk[i] = get_coo_size(A->nnz_blk[i], A->sr->el_size);
if (A->sr->has_csrmm)
A->rec_tsr->nnz_blk[i] = get_csr_size(A->nnz_blk[i], iprm.m, A->sr->el_size);
else
A->rec_tsr->nnz_blk[i] = get_coo_size(A->nnz_blk[i], A->sr->el_size);
new_sz_A += A->rec_tsr->nnz_blk[i];
}
A->rec_tsr->data = (char*)alloc(new_sz_A);
......@@ -715,8 +719,15 @@ namespace CTF_int {
char * data_ptr_out = A->rec_tsr->data;
char const * data_ptr_in = A->data;
for (i=0; i<nvirt_A; i++){
COO_Matrix cm(data_ptr_out);
cm.set_data(A->nnz_blk[i], A->order, A->lens, A->inner_ordering, nrow_idx, data_ptr_in, A->sr, phase);
if (A->sr->has_csrmm){
COO_Matrix cm(A->nnz_blk[i], A->sr);
cm.set_data(A->nnz_blk[i], A->order, A->lens, A->inner_ordering, nrow_idx, data_ptr_in, A->sr, phase);
CSR_Matrix cs(cm, iprm.m, A->sr, data_ptr_out);
cdealloc(cm.all_data);
} else {
COO_Matrix cm(data_ptr_out);
cm.set_data(A->nnz_blk[i], A->order, A->lens, A->inner_ordering, nrow_idx, data_ptr_in, A->sr, phase);
}
data_ptr_in += A->nnz_blk[i]*A->sr->pair_size();
data_ptr_out += A->rec_tsr->nnz_blk[i];
}
......@@ -3173,7 +3184,9 @@ namespace CTF_int {
} else
CTF_int::cdealloc(virt_dim);
seq_tsr_spctr * ctrseq = new seq_tsr_spctr(this, is_inner, inner_params, virt_blk_len_A, virt_blk_len_B, virt_blk_len_C, vrt_sz_C);
int krnl_type = is_inner;
if (krnl_type == 1 && A->sr->has_csrmm) krnl_type = 2;
seq_tsr_spctr * ctrseq = new seq_tsr_spctr(this, krnl_type, inner_params, virt_blk_len_A, virt_blk_len_B, virt_blk_len_C, vrt_sz_C);
if (is_top) {
hctr = ctrseq;
is_top = 0;
......@@ -3461,6 +3474,11 @@ namespace CTF_int {
#endif
// stat = zero_out_padding(type->tid_A);
// stat = zero_out_padding(type->tid_B);
#ifdef PROFILE
TAU_FSTART(pre_ctr_func_barrier);
MPI_Barrier(global_comm.cm);
TAU_FSTOP(pre_ctr_func_barrier);
#endif
TAU_FSTART(ctr_func);
/* Invoke the contraction algorithm */
A->topo->activate();
......@@ -3520,6 +3538,11 @@ namespace CTF_int {
ctrf->run(A->data, B->data, C->data);
A->topo->deactivate();
#ifdef PROFILE
TAU_FSTART(post_ctr_func_barrier);
MPI_Barrier(global_comm.cm);
TAU_FSTOP(post_ctr_func_barrier);
#endif
TAU_FSTOP(ctr_func);
#ifndef SEQ
if (C->is_cyclic)
......
......@@ -337,10 +337,13 @@ namespace CTF_int {
}
TAU_FSTOP(spctr_2d_general);
rec_ctr->run(op_A, new_nblk_A, new_size_blk_A,
op_B, nblk_B, size_blk_B,
op_C, nblk_C, size_blk_C,
op_C);
TAU_FSTART(spctr_2d_general);
new_C = C;
/*for (int i=0; i<ctr_sub_lda_C*ctr_lda_C; i++){
printf("[%d] P%d op_C[%d] = %lf\n",ctr_lda_C,idx_lyr,i, ((double*)op_C)[i]);
......
......@@ -174,7 +174,7 @@ namespace CTF_int {
char * C, int nblk_C, int64_t * size_blk_C,
char *& new_C){
int arank, brank, crank, i;
TAU_FSTART(spctr_replicate);
arank = 0, brank = 0, crank = 0;
for (i=0; i<ncdt_A; i++)
arank += cdt_A[i]->rank;
......@@ -252,10 +252,12 @@ namespace CTF_int {
rec_ctr->num_lyr = this->num_lyr;
rec_ctr->idx_lyr = this->idx_lyr;
TAU_FSTOP(spctr_replicate);
rec_ctr->run(buf_A, nblk_A, new_size_blk_A,
buf_B, nblk_B, new_size_blk_B,
C, nblk_C, size_blk_C,
new_C);
TAU_FSTART(spctr_replicate);
/*for (i=0; i<size_C; i++){
printf("P%d C[%d] = %lf\n",crank,i, ((double*)C)[i]);
}*/
......@@ -273,5 +275,6 @@ namespace CTF_int {
if (!is_sparse_B && brank != 0){
this->sr_B->set(B, this->sr_B->addid(), size_B);
}
TAU_FSTOP(spctr_replicate);
}
}
......@@ -6,6 +6,7 @@
#include "sp_seq_ctr.h"
#include "contraction.h"
#include "../sparse_formats/coo.h"
#include "../sparse_formats/csr.h"
#include "../tensor/untyped_tensor.h"
namespace CTF_int {
......@@ -27,7 +28,7 @@ namespace CTF_int {
seq_tsr_spctr::seq_tsr_spctr(contraction const * c,
bool use_coomm_,
int krnl_type_,
iparam const * inner_params,
int * virt_blk_len_A,
int * virt_blk_len_B,
......@@ -43,8 +44,8 @@ namespace CTF_int {
CTF_int::alloc_ptr(sizeof(int)*c->C->order, (void**)&new_sym_C);
memcpy(new_sym_C, c->C->sym, sizeof(int)*c->C->order);
this->use_coomm = use_coomm_;
if (use_coomm){
this->krnl_type = krnl_type_;
if (krnl_type > 0){
if (c->A->wrld->cdt.rank == 0){
DPRINTF(1,"Folded tensor n=%d m=%d k=%d\n", inner_params->n,
inner_params->m, inner_params->k);
......@@ -89,8 +90,8 @@ namespace CTF_int {
for (i=0; i<order_C; i++){
printf("edge_len_C[%d]=%d\n",i,edge_len_C[i]);
}
printf("is inner = %d\n", use_coomm);
if (use_coomm) printf("inner n = %d m= %d k = %d\n",
printf("kernel type is %d\n", krnl_type);
if (krnl_type>0) printf("inner n = %d m= %d k = %d\n",
inner_params.n, inner_params.m, inner_params.k);
}
......@@ -119,7 +120,7 @@ namespace CTF_int {
edge_len_C = (int*)CTF_int::alloc(sizeof(int)*order_C);
memcpy(edge_len_C, o->edge_len_C, sizeof(int)*order_C);
use_coomm = o->use_coomm;
krnl_type = o->krnl_type;
inner_params = o->inner_params;
is_custom = o->is_custom;
func = o->func;
......@@ -136,9 +137,9 @@ namespace CTF_int {
uint64_t size_A = sy_packed_size(order_A, edge_len_A, sym_A)*sr_A->el_size;
uint64_t size_B = sy_packed_size(order_B, edge_len_B, sym_B)*sr_B->el_size;
uint64_t size_C = sy_packed_size(order_C, edge_len_C, sym_C)*sr_C->el_size;
if (use_coomm) size_A *= inner_params.m*inner_params.k;
if (use_coomm) size_B *= inner_params.n*inner_params.k;
if (use_coomm) size_C *= inner_params.m*inner_params.n;
if (krnl_type>0) size_A *= inner_params.m*inner_params.k;
if (krnl_type>0) size_B *= inner_params.n*inner_params.k;
if (krnl_type>0) size_C *= inner_params.m*inner_params.n;
/*if (is_sparse_A) size_A = nnz_A*sr_A->pair_size();
if (is_sparse_B) size_B = nnz_B*sr_B->pair_size();
if (is_sparse_C) size_C = nnz_C*sr_C->pair_size();*/
......@@ -154,7 +155,7 @@ namespace CTF_int {
&idx_max, &rev_idx_map);
double flops = 2.0;
if (use_coomm) {
if (krnl_type>0) {
flops *= inner_params.m;
flops *= inner_params.n;
flops *= inner_params.k;
......@@ -184,7 +185,23 @@ namespace CTF_int {
ASSERT(!is_sparse_C);
ASSERT(nblk_A == 1);
if (use_coomm){
if (krnl_type==2){
// Do mm using CSR format
CSR_Matrix cA(A);
if (!sr_C->isequal(beta,sr_C->mulid())){
if (sr_C->isequal(beta,sr_C->addid())){
sr_C->set(C, beta, inner_params.sz_C);
} else {
sr_C->scal(inner_params.sz_C, beta, C, 1);
}
}
TAU_FSTART(CSRMM);
cA.csrmm(sr_A, inner_params.m, inner_params.n, inner_params.k,
alpha, B, sr_B, sr_C->mulid(), C, sr_C, func);
TAU_FSTOP(CSRMM);
} else if (krnl_type==1){
// Do mm using coordinate format
COO_Matrix cA(A);
if (!sr_C->isequal(beta,sr_C->mulid())){
if (sr_C->isequal(beta,sr_C->addid())){
......@@ -193,13 +210,16 @@ namespace CTF_int {
sr_C->scal(inner_params.sz_C, beta, C, 1);
}
}
TAU_FSTART(COOMM);
cA.coomm(sr_A, inner_params.m, inner_params.n, inner_params.k,
alpha, B, sr_B, sr_C->mulid(), C, sr_C, func);
TAU_FSTOP(COOMM);
} else {
ASSERT(size_blk_A[0]%sr_A->pair_size() == 0);
int64_t nnz_A = size_blk_A[0]/sr_A->pair_size();
TAU_FSTART(spA_dnB_dnC_seq);
spA_dnB_dnC_seq_ctr(this->alpha,
A,
nnz_A,
......@@ -222,6 +242,7 @@ namespace CTF_int {
sym_C,
idx_map_C,
func);
TAU_FSTOP(spA_dnB_dnC_seq);
}
}
......@@ -595,6 +616,7 @@ namespace CTF_int {
char * B, int nblk_B, int64_t const * size_blk_B,
char * C, int nblk_C, int64_t * size_blk_C,
char *& new_C){
TAU_FSTART(spctr_pin_keys);
char * X;
algstrct const * sr;
int64_t nnz = 0;
......@@ -651,10 +673,12 @@ namespace CTF_int {
pi.pin(nnz, order, lens, divisor, pi_new);
TAU_FSTOP(spctr_pin_keys);
rec_ctr->run(nA, nblk_A, size_blk_A,
nB, nblk_B, size_blk_B,
nC, nblk_C, size_blk_C,
new_C);
TAU_FSTART(spctr_pin_keys);
switch (AxBxC){
......@@ -671,5 +695,6 @@ namespace CTF_int {
depin(sr_C, order, lens, divisor, nblk_C, virt_dim, phys_rank, new_C, new_nnz_C, size_blk_C, new_C, true);
break;
}
TAU_FSTOP(spctr_pin_keys);
}
}
......@@ -43,7 +43,7 @@ namespace CTF_int{
int const * idx_map_C;
int * sym_C;
int use_coomm;
int krnl_type;
iparam inner_params;
int is_custom;
......@@ -82,7 +82,7 @@ namespace CTF_int{
}
seq_tsr_spctr(contraction const * s,
bool use_coomm,
int krnl_type,
iparam const * inner_params,
int * virt_blk_len_A,
int * virt_blk_len_B,
......
......@@ -69,7 +69,20 @@ namespace CTF_int {
CTF_BLAS::ZSCAL(&n,&alpha,X,&incX);
}
#if USE_SP_MKL
/**
 * \brief reference (non-MKL) kernel computing C = beta*C + alpha*A*B, where A is
 *        a sparse matrix in coordinate format with one-based row/column indices
 *        (A, rows_A, cols_A hold nnz_A entries), B is dense with leading
 *        dimension k, and C is dense with leading dimension m and n columns
 *
 * Expands in a scope where m, n, k, alpha, A, rows_A, cols_A, nnz_A, B, beta and
 * C are all defined. When beta is zero, C is overwritten with zeros instead of
 * being scaled, so that uninitialized or non-finite values already stored in C
 * (NaN*0 is NaN) cannot leak into the result; this follows the usual BLAS
 * convention for beta == 0.
 */
#define DEF_COOMM_KERNEL()                                        \
  if (beta == decltype(beta)(0)){                                 \
    for (int j=0; j<n; j++){                                      \
      for (int i=0; i<m; i++){                                    \
        C[j*m+i] = decltype(beta)(0);                             \
      }                                                           \
    }                                                             \
  } else {                                                        \
    for (int j=0; j<n; j++){                                      \
      for (int i=0; i<m; i++){                                    \
        C[j*m+i] *= beta;                                         \
      }                                                           \
    }                                                             \
  }                                                               \
  for (int i=0; i<nnz_A; i++){                                    \
    /* stored indices are one-based, shift to zero-based */       \
    int row_A = rows_A[i]-1;                                      \
    int col_A = cols_A[i]-1;                                      \
    for (int col_C=0; col_C<n; col_C++){                          \
      C[col_C*m+row_A] += alpha*A[i]*B[col_C*k+col_A];            \
    }                                                             \
  }
template <>
void default_coomm< float >
(int m,
......@@ -83,13 +96,16 @@ namespace CTF_int {
float const * B,
float beta,
float * C){
#if USE_SP_MKL
char transa = 'N';
char matdescra[6] = {'G',0,0,'F',0,0};
CTF_BLAS::MKL_SCOOMM(&transa, &m, &n, &k, &alpha,
matdescra, (float*)A, rows_A, cols_A, &nnz_A,
(float*)B, &k, &beta,
(float*)C, &m);
#else
DEF_COOMM_KERNEL();
#endif
}
template <>
......@@ -105,13 +121,18 @@ namespace CTF_int {
double const * B,
double beta,
double * C){
#if USE_SP_MKL
char transa = 'N';
char matdescra[6] = {'G',0,0,'F',0,0};
TAU_FSTART(MKL_DCOOMM);
CTF_BLAS::MKL_DCOOMM(&transa, &m, &n, &k, &alpha,
matdescra, (double*)A, rows_A, cols_A, &nnz_A,
(double*)B, &k, &beta,
(double*)C, &m);
TAU_FSTOP(MKL_DCOOMM);
#else
DEF_COOMM_KERNEL();
#endif
}
......@@ -128,13 +149,16 @@ namespace CTF_int {
std::complex<float> const * B,
std::complex<float> beta,
std::complex<float> * C){
#if USE_SP_MKL
char transa = 'N';
char matdescra[6] = {'G',0,0,'F',0,0};
CTF_BLAS::MKL_CCOOMM(&transa, &m, &n, &k, &alpha,
matdescra, (std::complex<float>*)A, rows_A, cols_A, &nnz_A,
(std::complex<float>*)B, &k, &beta,
(std::complex<float>*)C, &m);
#else
DEF_COOMM_KERNEL();
#endif
}
template <>
......@@ -150,16 +174,255 @@ namespace CTF_int {
std::complex<double> const * B,
std::complex<double> beta,
std::complex<double> * C){
#if USE_SP_MKL
char transa = 'N';
char matdescra[6] = {'G',0,0,'F',0,0};
CTF_BLAS::MKL_ZCOOMM(&transa, &m, &n, &k, &alpha,
matdescra, (std::complex<double>*)A, rows_A, cols_A, &nnz_A,
(std::complex<double>*)B, &k, &beta,
(std::complex<double>*)C, &m);
#else
DEF_COOMM_KERNEL();
#endif
}
// Whether a CSRMM kernel is available for each MKL-supported element type:
// true exactly when the library is built with USE_SP_MKL, false otherwise.
template <>
bool get_def_has_csrmm<float>(){
#if USE_SP_MKL
  return true;
#else
  return false;
#endif
}

template <>
bool get_def_has_csrmm<double>(){
#if USE_SP_MKL
  return true;
#else
  return false;
#endif
}

template <>
bool get_def_has_csrmm< std::complex<float> >(){
#if USE_SP_MKL
  return true;
#else
  return false;
#endif
}

template <>
bool get_def_has_csrmm< std::complex<double> >(){
#if USE_SP_MKL
  return true;
#else
  return false;
#endif
}
/*
 * COO -> CSR conversion for the element types MKL supports.
 *
 * Each specialization takes nz coordinate-format entries (coo_vs, coo_rs,
 * coo_cs) and fills the CSR arrays csr_vs (values), csr_cs (column indices) and
 * csr_rs (row pointers) for a matrix with nrow rows by calling the matching MKL
 * ?csrcoo routine. The job array {2,1,1,0,nz,0,0,0} is passed through to MKL
 * (per the MKL ?csrcoo documentation this requests a COO-to-CSR conversion with
 * one-based indexing on both sides -- confirm against the MKL reference when
 * changing it).
 *
 * MKL takes the nonzero count as a 32-bit int, so the 64-bit nz is narrowed and
 * the narrowing is asserted to be lossless rather than silently truncated.
 *
 * Without USE_SP_MKL there is no conversion kernel: the functions report an
 * error and assert.
 */
template <>
void def_coo_to_csr<float>(int64_t nz, int nrow, float * csr_vs, int * csr_cs, int * csr_rs, float const * coo_vs, int const * coo_rs, int const * coo_cs){
#if USE_SP_MKL
  int inz = (int)nz;
  ASSERT((int64_t)inz == nz);
  int job[8]={2,1,1,0,inz,0,0,0};
  int info = 1;
  CTF_BLAS::MKL_SCSRCOO(job, &nrow, csr_vs, csr_cs, csr_rs, &inz, (float*)coo_vs, coo_rs, coo_cs, &info);
#else
  printf("CTF ERROR: MKL required for COO to CSR conversion, should not be here\n");
  ASSERT(0);
#endif
}

template <>
void def_coo_to_csr<double>(int64_t nz, int nrow, double * csr_vs, int * csr_cs, int * csr_rs, double const * coo_vs, int const * coo_rs, int const * coo_cs){
#if USE_SP_MKL
  int inz = (int)nz;
  ASSERT((int64_t)inz == nz);
  int job[8]={2,1,1,0,inz,0,0,0};
  int info = 1;
  TAU_FSTART(MKL_DCSRCOO);
  CTF_BLAS::MKL_DCSRCOO(job, &nrow, csr_vs, csr_cs, csr_rs, &inz, (double*)coo_vs, coo_rs, coo_cs, &info);
  TAU_FSTOP(MKL_DCSRCOO);
#else
  printf("CTF ERROR: MKL required for COO to CSR conversion, should not be here\n");
  ASSERT(0);
#endif
}

template <>
void def_coo_to_csr<std::complex<float>>(int64_t nz, int nrow, std::complex<float> * csr_vs, int * csr_cs, int * csr_rs, std::complex<float> const * coo_vs, int const * coo_rs, int const * coo_cs){
#if USE_SP_MKL
  int inz = (int)nz;
  ASSERT((int64_t)inz == nz);
  int job[8]={2,1,1,0,inz,0,0,0};
  int info = 1;
  CTF_BLAS::MKL_CCSRCOO(job, &nrow, csr_vs, csr_cs, csr_rs, &inz, (std::complex<float>*)coo_vs, coo_rs, coo_cs, &info);
#else
  printf("CTF ERROR: MKL required for COO to CSR conversion, should not be here\n");
  ASSERT(0);
#endif
}

template <>
void def_coo_to_csr<std::complex<double>>(int64_t nz, int nrow, std::complex<double> * csr_vs, int * csr_cs, int * csr_rs, std::complex<double> const * coo_vs, int const * coo_rs, int const * coo_cs){
#if USE_SP_MKL
  int inz = (int)nz;
  ASSERT((int64_t)inz == nz);
  int job[8]={2,1,1,0,inz,0,0,0};
  int info = 1;
  CTF_BLAS::MKL_ZCSRCOO(job, &nrow, csr_vs, csr_cs, csr_rs, &inz, (std::complex<double>*)coo_vs, coo_rs, coo_cs, &info);
#else
  printf("CTF ERROR: MKL required for COO to CSR conversion, should not be here\n");
  ASSERT(0);
#endif
}
/*
 * CSR sparse-times-dense multiply for the element types MKL supports.
 *
 * With USE_SP_MKL each specialization forwards to the matching MKL ?csrmm
 * routine: A holds the values, cols_A the column indices, and rows_A the row
 * pointers (rows_A gives the row starts and rows_A+1 the row ends); B is passed
 * with leading dimension k and C with leading dimension m. nnz_A is not needed
 * by the MKL call. Without USE_SP_MKL the functions report an error and assert.
 */
template <>
void default_csrmm< float >
                 (int           m,
                  int           n,
                  int           k,
                  float         alpha,
                  float const * A,
                  int const *   rows_A,
                  int const *   cols_A,
                  int           nnz_A,
                  float const * B,
                  float         beta,
                  float *       C){
#if USE_SP_MKL
  char no_trans = 'N';
  char descr[6] = {'G',0,0,'F',0,0};
  int const * row_begin = rows_A;
  int const * row_end   = rows_A+1;
  CTF_BLAS::MKL_SCSRMM(&no_trans, &m, &n, &k, &alpha, descr, A, cols_A, row_begin, row_end, B, &k, &beta, C, &m);
#else
  printf("CTF ERROR: MKL required for CSRMM, should not be here\n");
  ASSERT(0);
#endif
}

template <>
void default_csrmm< double >
                 (int            m,
                  int            n,
                  int            k,
                  double         alpha,
                  double const * A,
                  int const *    rows_A,
                  int const *    cols_A,
                  int            nnz_A,
                  double const * B,
                  double         beta,
                  double *       C){
#if USE_SP_MKL
  char no_trans = 'N';
  char descr[6] = {'G',0,0,'F',0,0};
  int const * row_begin = rows_A;
  int const * row_end   = rows_A+1;
  TAU_FSTART(MKL_DCSRMM);
  CTF_BLAS::MKL_DCSRMM(&no_trans, &m, &n, &k, &alpha, descr, A, cols_A, row_begin, row_end, B, &k, &beta, C, &m);
  TAU_FSTOP(MKL_DCSRMM);
#else
  printf("CTF ERROR: MKL required for CSRMM, should not be here\n");
  ASSERT(0);
#endif
}

template <>
void default_csrmm< std::complex<float> >
                 (int                         m,
                  int                         n,
                  int                         k,
                  std::complex<float>         alpha,
                  std::complex<float> const * A,
                  int const *                 rows_A,
                  int const *                 cols_A,
                  int                         nnz_A,
                  std::complex<float> const * B,
                  std::complex<float>         beta,
                  std::complex<float> *       C){
#if USE_SP_MKL
  char no_trans = 'N';
  char descr[6] = {'G',0,0,'F',0,0};
  int const * row_begin = rows_A;
  int const * row_end   = rows_A+1;
  CTF_BLAS::MKL_CCSRMM(&no_trans, &m, &n, &k, &alpha, descr, A, cols_A, row_begin, row_end, B, &k, &beta, C, &m);
#else
  printf("CTF ERROR: MKL required for CSRMM, should not be here\n");
  ASSERT(0);
#endif
}

template <>
void default_csrmm< std::complex<double> >
                 (int                          m,
                  int                          n,
                  int                          k,
                  std::complex<double>         alpha,
                  std::complex<double> const * A,
                  int const *                  rows_A,
                  int const *                  cols_A,
                  int                          nnz_A,
                  std::complex<double> const * B,
                  std::complex<double>         beta,
                  std::complex<double> *       C){
#if USE_SP_MKL
  char no_trans = 'N';
  char descr[6] = {'G',0,0,'F',0,0};
  int const * row_begin = rows_A;
  int const * row_end   = rows_A+1;
  CTF_BLAS::MKL_ZCSRMM(&no_trans, &m, &n, &k, &alpha, descr, A, cols_A, row_begin, row_end, B, &k, &beta, C, &m);
#else
  printf("CTF ERROR: MKL required for CSRMM, should not be here\n");
  ASSERT(0);
#endif
}
}
......@@ -208,7 +208,6 @@ namespace CTF_int {
}
}
#if USE_SP_MKL
template <>
void default_coomm< float >
(int m,
......@@ -264,7 +263,113 @@ namespace CTF_int {
std::complex<double> const * B,
std::complex<double> beta,
std::complex<double> * C);
#endif
/**
 * \brief generic CSRMM fallback for element types without a specialization:
 *        prints an error and asserts, since no kernel is available
 */
template <typename dtype>
void default_csrmm(int           m,
                   int           n,
                   int           k,
                   dtype         alpha,
                   dtype const * A,
                   int const *   rows_A,
                   int const *   cols_A,
                   int           nnz_A,
                   dtype const * B,
                   dtype         beta,
                   dtype *       C){
  printf("CTF ERROR: no default CSRMM, only possible for types supported by MKL\n");
  ASSERT(0);
}

// Specializations for float, double and their complex counterparts are
// declared here and defined out of line.
template <>
void default_csrmm< float >(int m, int n, int k,
                            float alpha,
                            float const * A,
                            int const * rows_A,
                            int const * cols_A,
                            int nnz_A,
                            float const * B,
                            float beta,
                            float * C);

template <>
void default_csrmm< double >(int m, int n, int k,
                             double alpha,
                             double const * A,
                             int const * rows_A,
                             int const * cols_A,
                             int nnz_A,
                             double const * B,
                             double beta,
                             double * C);

template <>
void default_csrmm< std::complex<float> >(int m, int n, int k,
                                          std::complex<float> alpha,
                                          std::complex<float> const * A,
                                          int const * rows_A,
                                          int const * cols_A,
                                          int nnz_A,
                                          std::complex<float> const * B,
                                          std::complex<float> beta,
                                          std::complex<float> * C);

template <>
void default_csrmm< std::complex<double> >(int m, int n, int k,
                                           std::complex<double> alpha,
                                           std::complex<double> const * A,
                                           int const * rows_A,
                                           int const * cols_A,
                                           int nnz_A,
                                           std::complex<double> const * B,
                                           std::complex<double> beta,
                                           std::complex<double> * C);
/**
 * \brief reports whether a default CSRMM kernel exists for the element type;
 *        the generic answer is no, the specializations declared below may
 *        answer differently
 */
template <typename type>
bool get_def_has_csrmm(){
  return false;
}

template <> bool get_def_has_csrmm<float>();
template <> bool get_def_has_csrmm<double>();
template <> bool get_def_has_csrmm< std::complex<float> >();
template <> bool get_def_has_csrmm< std::complex<double> >();
/**
 * \brief generic COO-to-CSR conversion fallback for element types without a
 *        specialization: prints an error and asserts
 */
template <typename dtype>
void def_coo_to_csr(int64_t       nz,
                    int           nrow,
                    dtype *       csr_vs,
                    int *         csr_cs,
                    int *         csr_rs,
                    dtype const * coo_vs,
                    int const *   coo_rs,
                    int const *   coo_cs){
  printf("CTF ERROR: no default COO to CSR conversion kernel available, only possible for types supported by MKL\n");
  ASSERT(0);
}

// Specializations for float, double and their complex counterparts are
// declared here and defined out of line.
template <>
void def_coo_to_csr<float>(int64_t nz, int nrow,
                           float * csr_vs, int * csr_cs, int * csr_rs,
                           float const * coo_vs, int const * coo_rs, int const * coo_cs);

template <>
void def_coo_to_csr<double>(int64_t nz, int nrow,
                            double * csr_vs, int * csr_cs, int * csr_rs,
                            double const * coo_vs, int const * coo_rs, int const * coo_cs);

template <>
void def_coo_to_csr<std::complex<float>>(int64_t nz, int nrow,
                                         std::complex<float> * csr_vs, int * csr_cs, int * csr_rs,
                                         std::complex<float> const * coo_vs, int const * coo_rs, int const * coo_cs);

template <>
void def_coo_to_csr<std::complex<double>>(int64_t nz, int nrow,
                                          std::complex<double> * csr_vs, int * csr_cs, int * csr_rs,
                                          std::complex<double> const * coo_vs, int const * coo_rs, int const * coo_cs);
}
namespace CTF {
......@@ -340,6 +445,7 @@ namespace CTF {
faxpy = &CTF_int::default_axpy<dtype>;
fscal = &CTF_int::default_scal<dtype>;
fcoomm = &CTF_int::default_coomm<dtype>;
this->has_csrmm = CTF_int::get_def_has_csrmm<dtype>();
}
void mul(char const * a,
......@@ -478,8 +584,31 @@ namespace CTF {
}
} else { assert(0); }
}
};
void coo_to_csr(int64_t nz, int nrow, char * csr_vs, int * csr_cs, int * csr_rs, char const * coo_vs, int const * coo_rs, int const * coo_cs) const {
assert(this->has_csrmm);
CTF_int::def_coo_to_csr(nz, nrow, (dtype *)csr_vs, csr_cs, csr_rs, (dtype const *) coo_vs, coo_rs, coo_cs);
}
/** \brief sparse version of gemm using CSR format for A */
void csrmm(int m,
int n,
int k,
char const * alpha,
char const * A,
int const * rows_A,
int const * cols_A,
int64_t nnz_A,
char const * B,
char const * beta,
char * C,
CTF_int::bivar_function const * func) const {
assert(this->has_csrmm);
assert(func == NULL);
CTF_int::default_csrmm<dtype>(m,n,k,((dtype*)alpha)[0],(dtype*)A,rows_A,cols_A,nnz_A,(dtype*)B,((dtype*)beta)[0],(dtype*)C);
}
};
/**
* @}
*/
......
......@@ -207,7 +207,7 @@ namespace CTF {
if (!universe_exists){
universe_exists = true;
universe = *this;
CTF_int::mem_create();
// CTF_int::mem_create();
is_copy = true;
}
}
......
......@@ -21,6 +21,14 @@
#define MKL_DCOOMM mkl_dcoomm_
#define MKL_CCOOMM mkl_ccoomm_
#define MKL_ZCOOMM mkl_zcoomm_
#define MKL_SCSRCOO mkl_scsrcoo_
#define MKL_DCSRCOO mkl_dcsrcoo_
#define MKL_CCSRCOO mkl_ccsrcoo_
#define MKL_ZCSRCOO mkl_zcsrcoo_
#define MKL_SCSRMM mkl_scsrmm_
#define MKL_DCSRMM mkl_dcsrmm_
#define MKL_CCSRMM mkl_ccsrmm_
#define MKL_ZCSRMM mkl_zcsrmm_
#else
#define SGEMM sgemm
#define DGEMM dgemm
......@@ -41,6 +49,14 @@
#define MKL_DCOOMM mkl_dcoomm
#define MKL_CCOOMM mkl_ccoomm
#define MKL_ZCOOMM mkl_zcoomm
#define MKL_SCSRCOO mkl_scsrcoo
#define MKL_DCSRCOO mkl_dcsrcoo
#define MKL_CCSRCOO mkl_ccsrcoo
#define MKL_ZCSRCOO mkl_zcsrcoo
#define MKL_SCSRMM mkl_scsrmm
#define MKL_DCSRMM mkl_dcsrmm
#define MKL_CCSRMM mkl_ccsrmm
#define MKL_ZCSRMM mkl_zcsrmm
#endif
namespace CTF_BLAS {
......@@ -259,6 +275,68 @@ namespace CTF_BLAS {
std::complex<double> * c,
int * ldc);
// Prototypes for the MKL ?csrcoo conversion routines (single, double, complex
// single, complex double). As called from def_coo_to_csr in this library:
//   job            - 8-entry integer array of conversion options
//   n              - number of rows
//   acsr, ja, ia   - CSR values, column indices and row pointers (destination
//                    buffers in the COO->CSR direction used by def_coo_to_csr,
//                    even though ja/ia are const-qualified here)
//   nnz            - number of nonzeros
//   acoo, rowind, colind - COO values, row indices and column indices
//   info           - status word passed by the caller
extern "C"
void MKL_SCSRCOO(int const * job,
int * n,
float * acsr,
int const * ja,
int const * ia,
int * nnz,
float * acoo,
int const * rowind,
int const * colind,
int * info);
extern "C"
void MKL_DCSRCOO(int const * job,
int * n,
double * acsr,
int const * ja,
int const * ia,
int * nnz,
double * acoo,
int const * rowind,
int const * colind,
int * info);
extern "C"
void MKL_CCSRCOO(int const * job,
int * n,
std::complex<float> * acsr,
int const * ja,
int const * ia,
int * nnz,
std::complex<float> * acoo,
int const * rowind,
int const * colind,
int * info);
extern "C"
void MKL_ZCSRCOO(int const * job,
int * n,
std::complex<double> * acsr,
int const * ja,
int const * ia,
int * nnz,
std::complex<double> * acoo,
int const * rowind,
int const * colind,
int * info);
// Prototypes for the MKL ?csrmm sparse (CSR) times dense multiply routines.
// As called from default_csrmm in this library: val holds the nonzero values,
// indx the column indices, pntrb/pntre the row begin/end pointers (passed as
// rows_A and rows_A+1), b and c are the dense operands with leading dimensions
// ldb and ldc, and alpha/beta are the scaling factors.
extern "C"
void MKL_SCSRMM(const char *transa , const int *m , const int *n , const int *k , const float *alpha , const char *matdescra , const float *val , const int *indx , const int *pntrb , const int *pntre , const float *b , const int *ldb , const float *beta , float *c , const int *ldc );
extern "C"
void MKL_DCSRMM(const char *transa , const int *m , const int *n , const int *k , const double *alpha , const char *matdescra , const double *val , const int *indx , const int *pntrb , const int *pntre , const double *b , const int *ldb , const double *beta , double *c , const int *ldc );
extern "C"
void MKL_CCSRMM(const char *transa , const int *m , const int *n , const int *k , const std::complex<float> *alpha , const char *matdescra , const std::complex<float> *val , const int *indx , const int *pntrb , const int *pntre , const std::complex<float> *b , const int *ldb , const std::complex<float> *beta , std::complex<float> *c , const int *ldc );
extern "C"
void MKL_ZCSRMM(const char *transa , const int *m , const int *n , const int *k , const std::complex<double> *alpha , const char *matdescra , const std::complex<double> *val , const int *indx , const int *pntrb , const int *pntre , const std::complex<double> *b , const int *ldb , const std::complex<double> *beta , std::complex<double> *c , const int *ldc );
#endif
......
......@@ -164,7 +164,7 @@ namespace CTF_int {
*/
void mem_exit(int rank){
instance_counter--;
assert(instance_counter >= 0);
//assert(instance_counter >= 0);
#ifndef PRODUCTION
if (instance_counter == 0){
for (int i=0; i<max_threads; i++){
......
LOBJS = coo.o
LOBJS = coo.o csr.o
OBJS = $(addprefix $(ODIR)/, $(LOBJS))
#%d | r ! grep -ho "\.\..*\.h" *.cxx *.h | sort | uniq
......
#include "coo.h"
#include "../shared/util.h"
namespace CTF_int {
int64_t get_coo_size(int64_t nnz, int val_size){
return nnz*(val_size+sizeof(int)*2)+2*sizeof(int64_t);
......@@ -20,8 +21,12 @@ namespace CTF_int {
return ((int64_t*)all_data)[0];
}
int COO_Matrix::val_size() const {
return ((int64_t*)all_data)[1];
}
int64_t COO_Matrix::size() const {
return nnz()*((int64_t*)all_data)[1];
return get_coo_size(nnz(),val_size());
}
char * COO_Matrix::vals() const {
......@@ -30,22 +35,23 @@ namespace CTF_int {
int * COO_Matrix::rows() const {
int64_t n = this->nnz();
int val_size = ((int64_t*)all_data)[1];
int v_sz = this->val_size();
return (int*)(all_data + n*val_size+2*sizeof(int64_t));
return (int*)(all_data + n*v_sz+2*sizeof(int64_t));
}
int * COO_Matrix::cols() const {
int64_t n = this->nnz();
int val_size = ((int64_t*)all_data)[1];
int v_sz = ((int64_t*)all_data)[1];
return (int*)(all_data + n*(val_size+sizeof(int))+2*sizeof(int64_t));
return (int*)(all_data + n*(v_sz+sizeof(int))+2*sizeof(int64_t));
}
void COO_Matrix::set_data(int64_t nz, int order, int const * lens, int const * rev_ordering, int nrow_idx, char const * tsr_data, algstrct const * sr, int const * phase){
TAU_FSTART(convert_to_COO);
((int64_t*)all_data)[0] = nz;
((int64_t*)all_data)[1] = sr->el_size;
int val_size = sr->el_size;
int v_sz = sr->el_size;
int * rev_ord_lens = (int*)alloc(sizeof(int)*order);
int * ordering = (int*)alloc(sizeof(int)*order);
......@@ -97,8 +103,9 @@ namespace CTF_int {
k=k/lens[j];
}
// printf("k=%ld col = %d row = %d\n", pi[i].k(), cs[i], rs[i]);
memcpy(vs+val_size*i, pi[i].d(), val_size);
memcpy(vs+v_sz*i, pi[i].d(), v_sz);
}
TAU_FSTOP(convert_to_COO);
}
......
......@@ -19,6 +19,8 @@ namespace CTF_int {
int64_t nnz() const;
int val_size() const;
int64_t size() const;
char * vals() const;
......
#include "csr.h"
#include "../shared/util.h"
namespace CTF_int {
/**
 * \brief number of bytes needed to store a CSR matrix: a header of three
 *        int64_t, then per nonzero one value and one int, then nrow+1 ints
 * \param[in] nnz number of nonzeros
 * \param[in] nrow number of rows
 * \param[in] val_size size of each value in bytes
 */
int64_t get_csr_size(int64_t nnz, int nrow, int val_size){
  size_t const bytes_per_nonzero = val_size+sizeof(int);
  size_t const row_ptr_bytes     = (nrow+1)*sizeof(int);
  size_t const header_bytes      = 3*sizeof(int64_t);
  return nnz*bytes_per_nonzero+row_ptr_bytes+header_bytes;
}
/**
 * \brief allocates a buffer large enough for nnz nonzeros and nrow rows and
 *        records nnz, the element size of sr and nrow in its header; the
 *        value and index arrays are left unfilled
 */
CSR_Matrix::CSR_Matrix(int64_t nnz, int nrow, algstrct const * sr){
  all_data = (char*)alloc(get_csr_size(nnz, nrow, sr->el_size));
  int64_t * header = (int64_t*)all_data;
  header[0] = nnz;
  header[1] = sr->el_size;
  header[2] = nrow;
}
/** \brief wraps an existing serialized CSR buffer without copying it */
CSR_Matrix::CSR_Matrix(char * all_data_)
  : all_data(all_data_) {
}
/**
 * \brief builds a CSR matrix with nrow rows from a coordinate-format matrix
 * \param[in] coom source matrix in coordinate format
 * \param[in] nrow number of rows of the matrix
 * \param[in] sr algebraic structure providing the COO-to-CSR conversion
 * \param[in] data buffer to build the CSR matrix in; when NULL a buffer of
 *                 get_csr_size() bytes is allocated instead
 */
CSR_Matrix::CSR_Matrix(COO_Matrix const & coom, int nrow, algstrct const * sr, char * data){
  int64_t const num_nonzero = coom.nnz();
  int64_t const elem_bytes  = coom.val_size();
  int const *   src_rows    = coom.rows();
  int const *   src_cols    = coom.cols();
  char const *  src_vals    = coom.vals();

  if (data != NULL){
    all_data = data;
  } else {
    all_data = (char*)alloc(get_csr_size(num_nonzero, nrow, elem_bytes));
  }

  // the header must be written first, the accessors below read it
  int64_t * header = (int64_t*)all_data;
  header[0] = num_nonzero;
  header[1] = elem_bytes;
  header[2] = nrow;

  char * dst_vals = vals();
  int *  dst_rows = rows();
  int *  dst_cols = cols();

  sr->coo_to_csr(num_nonzero, nrow, dst_vals, dst_cols, dst_rows, src_vals, src_rows, src_cols);
}
/** \brief number of nonzeros, read from the first header entry */
int64_t CSR_Matrix::nnz() const {
  int64_t const * header = (int64_t const *)all_data;
  return header[0];
}
/** \brief size of each stored value in bytes, read from the second header entry */
int CSR_Matrix::val_size() const {
  int64_t const * header = (int64_t const *)all_data;
  return header[1];
}
int64_t CSR_Matrix::size() const {
return get_csr_size(nnz(),nrow(),val_size());
}
/** \brief number of rows, read from the third header entry */
int CSR_Matrix::nrow() const {
  int64_t const * header = (int64_t const *)all_data;
  return header[2];
}
/** \brief start of the value array, located right after the three-entry header */
char * CSR_Matrix::vals() const {
  int64_t * header = (int64_t*)all_data;
  return (char*)(header + 3);
}
/** \brief start of the row-pointer array, located right after the nnz() stored values */
int * CSR_Matrix::rows() const {
  int64_t const value_bytes = this->nnz()*this->val_size();
  char * after_header = all_data + 3*sizeof(int64_t);
  return (int*)(after_header + value_bytes);
}
/**
 * \brief start of the column-index array, located after the values and the
 *        nrow()+1 row-pointer entries
 */
int * CSR_Matrix::cols() const {
  int64_t const value_bytes   = this->nnz()*this->val_size();
  int64_t const num_row_ptrs  = (int64_t)this->nrow()+1;
  char * after_header = all_data + 3*sizeof(int64_t);
  char * after_values = after_header + value_bytes;
  return (int*)(after_values + num_row_ptrs*sizeof(int));
}
/**
 * \brief multiplies this CSR matrix (as A) by dense B into dense C by handing
 *        the stored arrays to the csrmm kernel of sr_A; asserts that B and C
 *        use the same element size as A
 */
void CSR_Matrix::csrmm(algstrct const * sr_A, int m, int n, int k, char const * alpha, char const * B, algstrct const * sr_B, char const * beta, char * C, algstrct const * sr_C, bivar_function const * func){
  ASSERT(sr_B->el_size == sr_A->el_size);
  ASSERT(sr_C->el_size == sr_A->el_size);

  int64_t const num_nonzero = nnz();
  char const *  values      = vals();
  int const *   row_ptrs    = rows();
  int const *   col_idx     = cols();

  sr_A->csrmm(m, n, k, alpha, values, row_ptrs, col_idx, num_nonzero, B, beta, C, func);
}
}
#ifndef __CSR_H__
#define __CSR_H__
#include "../tensor/algstrct.h"
#include "coo.h"
namespace CTF_int {
class bivar_function;
int64_t get_csr_size(int64_t nnz, int nrow, int val_size);
/**
 * \brief serialized sparse matrix in compressed-sparse-row (CSR) format
 *
 * All state lives in one contiguous byte buffer (all_data) laid out as:
 * three int64_t header entries (nnz, value size in bytes, number of rows),
 * then the nnz values, then nrow+1 ints returned by rows(), then nnz ints
 * returned by cols(). get_csr_size() gives the total byte count.
 * The class declares no destructor, so the buffer is not released by it.
 */
class CSR_Matrix{
public:
/** \brief serialized buffer holding header, values, row pointers and column indices */
char * all_data;
/** \brief allocates a buffer for nnz nonzeros and nrow rows of elements of sr and fills in the header only */
CSR_Matrix(int64_t nnz, int nrow, algstrct const * sr);
/** \brief wraps an existing serialized CSR buffer (no copy) */
CSR_Matrix(char * all_data);
/** \brief converts the COO matrix coom with nrow rows to CSR via sr, writing into data or, if data is NULL, into a newly allocated buffer */
CSR_Matrix(COO_Matrix const & coom, int nrow, algstrct const * sr, char * data=NULL);
/** \brief number of nonzeros stored in the header */
int64_t nnz() const;
/** \brief total size of the serialized buffer in bytes */
int64_t size() const;
/** \brief number of rows stored in the header */
int nrow() const;
/** \brief size of each value in bytes stored in the header */
int val_size() const;
/** \brief pointer to the array of nnz values */
char * vals() const;
/** \brief pointer to the nrow+1 ints stored after the values (passed to the csrmm kernel as rows_A) */
int * rows() const;
/** \brief pointer to the nnz ints stored after rows() (passed to the csrmm kernel as cols_A) */
int * cols() const;
// void set_data(int64_t nz, int order, int const * lens, int const * ordering, int nrow_idx, char const * tsr_data, algstrct const * sr, int const * phase);
/**
* \brief computes C = beta*C + func(alpha*A*B) where A is this CSR_Matrix, while B and C are dense
*        (forwards to sr_A->csrmm with this matrix's values, row and column arrays)
*/
void csrmm(algstrct const * sr_A, int m, int n, int k, char const * alpha, char const * B, algstrct const * sr_B, char const * beta, char * C, algstrct const * sr_C, bivar_function const * func);
};
}
#endif
......@@ -109,6 +109,7 @@ namespace CTF_int {
}
/** \brief constructs an algebraic structure with the given element size and no CSRMM kernel */
algstrct::algstrct(int el_size_)
  : el_size(el_size_),
    has_csrmm(false) {
}
......@@ -247,6 +248,11 @@ namespace CTF_int {
}
return iseq;
}
/**
 * \brief base-class COO-to-CSR conversion: no conversion is available here, so
 *        this reports an error and asserts
 */
void algstrct::coo_to_csr(int64_t       nz,
                          int           nrow,
                          char *        csr_vs,
                          int *         csr_cs,
                          int *         csr_rs,
                          char const *  coo_vs,
                          int const *   coo_rs,
                          int const *   coo_cs) const {
  fputs("CTF ERROR: cannot convert elements of this algebraic structure to CSR\n", stdout);
  ASSERT(0);
}
void algstrct::acc(char * b, char const * beta, char const * a, char const * alpha) const {
char tmp[el_size];
......@@ -415,6 +421,12 @@ namespace CTF_int {
printf("CTF ERROR: coomm not present for this algebraic structure\n");
ASSERT(0);
}
/**
 * \brief base-class CSR sparse-times-dense multiply: no kernel is available
 *        here, so this reports an error and asserts
 */
void algstrct::csrmm(int          m,
                     int          n,
                     int          k,
                     char const * alpha,
                     char const * A,
                     int const *  rows_A,
                     int const *  cols_A,
                     int64_t      nnz_A,
                     char const * B,
                     char const * beta,
                     char *       C,
                     bivar_function const * func) const {
  fputs("CTF ERROR: csrmm not present for this algebraic structure\n", stdout);
  ASSERT(0);
}
ConstPairIterator::ConstPairIterator(PairIterator const & pi){
sr=pi.sr; ptr=pi.ptr;
......
......@@ -14,6 +14,8 @@ namespace CTF_int {
public:
/** \brief size of each element of algstrct in bytes */
int el_size;
/** \brief whether there is an MKL CSRMM routine for this algebraic structure */
bool has_csrmm;
/** \brief datatype for pairs, always custom create3d */
// MPI_Datatype pmdtype;
......@@ -28,7 +30,7 @@ namespace CTF_int {
/**
* \brief default constructor
*/
algstrct(){}
algstrct(){ has_csrmm = false; }
/**
* \brief copy constructor
......@@ -162,8 +164,25 @@ namespace CTF_int {
char * C,
bivar_function const * func) const;
/** \brief sparse version of gemm using CSR format for A */
virtual void csrmm(int m,
int n,
int k,
char const * alpha,
char const * A,
int const * rows_A,
int const * cols_A,
int64_t nnz_A,
char const * B,
char const * beta,
char * C,
bivar_function const * func) const;
/** \brief returns true if algstrct elements a and b are equal */
virtual bool isequal(char const * a, char const * b) const;
/** \brief converts coordinate sparse matrix layout to CSR layout */
virtual void coo_to_csr(int64_t nz, int nrow, char * csr_vs, int * csr_cs, int * csr_rs, char const * coo_vs, int const * coo_rs, int const * coo_cs) const;
/** \brief compute b=beta*b + alpha*a */
void acc(char * b, char const * beta, char const * a, char const * alpha) const;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment