Commit d86a9631 authored by Edgar Solomonik's avatar Edgar Solomonik

reorganized sparse/dense summation construction and execution so they use separate objects/functions.
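In outline, the split works as follows: construct_sum now only computes the physical-mapping table and dispatches, with all chain construction moved into the two new routines. A condensed sketch of the control flow, matching the full construct_sum in the summation.cxx diff below (phys_mapped setup elided):

tsum * summation::construct_sum(int inner_stride){
  int * phys_mapped;
  // ... fill phys_mapped[2*i+{0,1}] from the edge maps of A and B ...
  tsum * htsum;
  if (A->is_sparse || B->is_sparse)
    htsum = construct_sparse_sum(phys_mapped);              // sparse path: tspsum chain
  else
    htsum = construct_dense_sum(inner_stride, phys_mapped); // dense path: tsum chain
  CTF_int::cdealloc(phys_mapped);
  return htsum;
}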
......@@ -118,7 +118,6 @@ namespace CTF_int {
* \brief destructor
*/
~strp_sum();
strp_sum(){}
strp_sum(summation const * s);
};
......
LOBJS = summation.o sym_seq_sum.o sum_tsr.o spr_seq_sum.o
LOBJS = summation.o sym_seq_sum.o sum_tsr.o spr_seq_sum.o spsum_tsr.o
OBJS = $(addprefix $(ODIR)/, $(LOBJS))
#%d | r ! grep -ho "\.\..*\.h" *.cxx *.h | sort | uniq
......
#ifndef __SPSUM_TSR_H__
#define __SPSUM_TSR_H__
#include "sum_tsr.h"
namespace CTF_int {
class tspsum : public tsum {
public:
bool is_sparse_A;
int64_t nnz_A;
int nvirt_A;
int64_t * nnz_blk_A;
bool is_sparse_B;
int64_t nnz_B;
int nvirt_B;
int64_t * nnz_blk_B;
int64_t new_nnz_B;
char * new_B;
~tspsum();
tspsum(tspsum * other);
virtual tspsum * clone() { return NULL; }
tspsum(summation const * s);
virtual void set_nnz_blk_A(int64_t const * nnbA){
if (nnbA != NULL) memcpy(nnz_blk_A, nnbA, nvirt_A*sizeof(int64_t));
}
};
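The set_nnz_blk_A() virtual carries A's per-virtual-block nonzero counts through the whole recursion: the base class copies nvirt_A entries, and the recursive wrappers below forward the call to their rec_tsum. A minimal sketch of the intended use (the call site here is hypothetical, not part of this commit):

// Hypothetical call site: a single call at the root of the chain updates
// every level, since each override copies locally and then recurses.
tspsum * root = ...; // chain built by summation::construct_sparse_sum
root->set_nnz_blk_A(A->nnz_blk); // A->nnz_blk holds per-block nonzero counts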
class tspsum_virt : public tspsum {
public:
/* Class to be called on sub-blocks */
tspsum * rec_tsum;
int num_dim;
int * virt_dim;
int order_A;
int64_t blk_sz_A; //if dense
int const * idx_map_A;
int order_B;
int64_t blk_sz_B; //if dense
int const * idx_map_B;
void run();
void print();
int64_t mem_fp();
void set_nnz_blk_A(int64_t const * nnbA){
tspsum::set_nnz_blk_A(nnbA);
rec_tsum->set_nnz_blk_A(nnbA);
}
tspsum * clone();
/**
* \brief iterates over the dense virtualization block grid and sums
*/
tspsum_virt(tspsum * other);
~tspsum_virt();
tspsum_virt(summation const * s);
};
/**
* \brief performs replication along a dimension, generates 2.5D algs
*/
class tspsum_replicate : public tspsum {
public:
int64_t size_A; /* size of A blocks */
int64_t size_B; /* size of B blocks */
int ncdt_A; /* number of processor dimensions to replicate A along */
int ncdt_B; /* number of processor dimensions to replicate B along */
CommData ** cdt_A;
CommData ** cdt_B;
/* Class to be called on sub-blocks */
tspsum * rec_tsum;
void run();
void print();
int64_t mem_fp();
tspsum * clone();
void set_nnz_blk_A(int64_t const * nnbA){
tspsum::set_nnz_blk_A(nnbA);
rec_tsum->set_nnz_blk_A(nnbA);
}
tspsum_replicate(tspsum * other);
~tspsum_replicate();
tspsum_replicate(summation const * s,
int const * phys_mapped,
int64_t blk_sz_A,
int64_t blk_sz_B);
};
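Here phys_mapped has dimension 2*nphys_dim: entry 2*i records whether A is mapped to physical processor dimension i, and entry 2*i+1 whether B is. A dimension occupied by only one operand forces replication of the other, which is what the constructor counts (excerpted from construct_sparse_sum in the diff below):

for (i=0; i<nphys_dim; i++){
  if (phys_mapped[2*i+0] == 0 && phys_mapped[2*i+1] == 1)
    rtsum->ncdt_A++; // B occupies dim i but A does not: replicate A along it
  if (phys_mapped[2*i+1] == 0 && phys_mapped[2*i+0] == 1)
    rtsum->ncdt_B++; // A occupies dim i but B does not: replicate B along it
}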
class seq_tsr_spsum : public tspsum {
public:
int order_A;
int * edge_len_A;
int const * idx_map_A;
int * sym_A;
int order_B;
int * edge_len_B;
int const * idx_map_B;
int * sym_B;
//fseq_tsr_sum func_ptr;
int is_inner;
int inr_stride;
int64_t map_pfx;
int is_custom;
univar_function const * func; //fseq_elm_sum custom_params;
/**
* \brief wraps user sequential function signature
*/
void run();
void print();
int64_t mem_fp();
tspsum * clone();
void set_nnz_blk_A(int64_t const * nnbA){
tspsum::set_nnz_blk_A(nnbA);
}
/**
* \brief copies sum object
* \param[in] other object to copy
*/
seq_tsr_spsum(tspsum * other);
~seq_tsr_spsum(){ CTF_int::cdealloc(edge_len_A), CTF_int::cdealloc(edge_len_B),
CTF_int::cdealloc(sym_A), CTF_int::cdealloc(sym_B); };
seq_tsr_spsum(summation const * s);
};
class tspsum_map : public tspsum {
public:
tspsum * rec_tsum;
int nmap_idx;
int64_t * map_idx_len;
int64_t * map_idx_lda;
void run();
void print();
int64_t mem_fp();
tspsum * clone();
void set_nnz_blk_A(int64_t const * nnbA){
tspsum::set_nnz_blk_A(nnbA);
rec_tsum->set_nnz_blk_A(nnbA);
}
tspsum_map(tspsum * other);
~tspsum_map();
tspsum_map(summation const * s);
};
class tspsum_permute : public tspsum {
public:
tspsum * rec_tsum;
bool A_or_B; //if false perm_B
int order;
int * lens_new;
int * lens_old; // FIXME = lens_new?
int * p;
void run();
void print();
int64_t mem_fp();
tspsum * clone();
void set_nnz_blk_A(int64_t const * nnbA){
tspsum::set_nnz_blk_A(nnbA);
rec_tsum->set_nnz_blk_A(nnbA);
}
tspsum_permute(tspsum * other);
~tspsum_permute();
tspsum_permute(summation const * s, bool A_or_B, int const * lens);
};
class tspsum_pin_keys : public tspsum {
public:
tspsum * rec_tsum;
bool A_or_B;
int order;
int const * lens;
int * divisor;
int * virt_dim;
int * phys_rank;
void run();
void print();
int64_t mem_fp();
tspsum * clone();
void set_nnz_blk_A(int64_t const * nnbA){
tspsum::set_nnz_blk_A(nnbA);
rec_tsum->set_nnz_blk_A(nnbA);
}
tspsum_pin_keys(tspsum * other);
~tspsum_pin_keys();
tspsum_pin_keys(summation const * s, bool A_or_B);
};
}
#endif
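construct_sparse_sum (in the summation.cxx diff below) composes these classes outermost-to-innermost through their rec_tsum pointers. A representative chain when both operands are sparse and one physical dimension is half-mapped (the exact chain depends on the mapping):

// tspsum_pin_keys(A)              -- pin A's sparse keys to local coordinates
//   -> tspsum_permute(A)          -- reorder A's data for the summation index order
//     -> tspsum_pin_keys(B) -> tspsum_permute(B)   -- likewise for B
//       -> tspsum_replicate       -- broadcast along half-mapped physical dims
//         -> tspsum_virt          -- iterate over the virtual block grid
//           -> seq_tsr_spsum      -- sequential kernel on each block pair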
......@@ -73,17 +73,6 @@ namespace CTF_int {
char const * beta;
void * buffer;
bool is_sparse_A;
int64_t nnz_A;
int nvirt_A;
int64_t * nnz_blk_A;
bool is_sparse_B;
int64_t nnz_B;
int nvirt_B;
int64_t * nnz_blk_B;
int64_t new_nnz_B;
char * new_B;
virtual void run() {};
virtual void print() {};
// virtual int64_t calc_new_nnz_B() { return nnz_B; } //if sparse
......@@ -93,13 +82,8 @@ namespace CTF_int {
*/
virtual int64_t mem_fp() { return 0; };
virtual tsum * clone() { return NULL; };
virtual void set_nnz_blk_A(int64_t const * nnbA){
if (nnbA != NULL) memcpy(nnz_blk_A, nnbA, nvirt_A*sizeof(int64_t));
}
virtual ~tsum();
tsum(tsum * other);
tsum(){ buffer = NULL; }
tsum(summation const * s);
};
......@@ -120,10 +104,6 @@ namespace CTF_int {
void run();
void print();
int64_t mem_fp();
void set_nnz_blk_A(int64_t const * nnbA){
tsum::set_nnz_blk_A(nnbA);
rec_tsum->set_nnz_blk_A(nnbA);
}
tsum * clone();
/**
......@@ -131,7 +111,6 @@ namespace CTF_int {
*/
tsum_virt(tsum * other);
~tsum_virt();
tsum_virt(){}
tsum_virt(summation const * s);
};
......@@ -155,15 +134,13 @@ namespace CTF_int {
void print();
int64_t mem_fp();
tsum * clone();
void set_nnz_blk_A(int64_t const * nnbA){
tsum::set_nnz_blk_A(nnbA);
rec_tsum->set_nnz_blk_A(nnbA);
}
tsum_replicate(tsum * other);
~tsum_replicate();
tsum_replicate(){}
tsum_replicate(summation const * s);
tsum_replicate(summation const * s,
int const * phys_mapped,
int64_t blk_sz_A,
int64_t blk_sz_B);
};
class seq_tsr_sum : public tsum {
......@@ -193,9 +170,6 @@ namespace CTF_int {
void print();
int64_t mem_fp();
tsum * clone();
void set_nnz_blk_A(int64_t const * nnbA){
tsum::set_nnz_blk_A(nnbA);
}
/**
* \brief copies sum object
......@@ -204,82 +178,10 @@ namespace CTF_int {
seq_tsr_sum(tsum * other);
~seq_tsr_sum(){ CTF_int::cdealloc(edge_len_A), CTF_int::cdealloc(edge_len_B),
CTF_int::cdealloc(sym_A), CTF_int::cdealloc(sym_B); };
seq_tsr_sum(){}
seq_tsr_sum(summation const * s);
};
class tsum_sp_map : public tsum {
public:
tsum * rec_tsum;
int nmap_idx;
int64_t * map_idx_len;
int64_t * map_idx_lda;
void run();
void print();
int64_t mem_fp();
tsum * clone();
void set_nnz_blk_A(int64_t const * nnbA){
tsum::set_nnz_blk_A(nnbA);
rec_tsum->set_nnz_blk_A(nnbA);
}
tsum_sp_map(tsum * other);
~tsum_sp_map();
tsum_sp_map(){}
tsum_sp_map(summation const * s);
};
class tsum_sp_permute : public tsum {
public:
tsum * rec_tsum;
bool A_or_B; //if false perm_B
int order;
int * lens_new;
int * lens_old; // FIXME = lens_new?
int * p;
void run();
void print();
int64_t mem_fp();
tsum * clone();
void set_nnz_blk_A(int64_t const * nnbA){
tsum::set_nnz_blk_A(nnbA);
rec_tsum->set_nnz_blk_A(nnbA);
}
tsum_sp_permute(tsum * other);
~tsum_sp_permute();
tsum_sp_permute(){}
tsum_sp_permute(summation const * s, bool A_or_B, int const * lens);
};
class tsum_sp_pin_keys : public tsum {
public:
tsum * rec_tsum;
bool A_or_B;
int order;
int const * lens;
int * divisor;
int * virt_dim;
int * phys_rank;
void run();
void print();
int64_t mem_fp();
tsum * clone();
void set_nnz_blk_A(int64_t const * nnbA){
tsum::set_nnz_blk_A(nnbA);
rec_tsum->set_nnz_blk_A(nnbA);
}
tsum_sp_pin_keys(tsum * other);
~tsum_sp_pin_keys();
tsum_sp_pin_keys(summation const * s, bool A_or_B);
};
/**
* \brief invert index map
* \param[in] order_A number of dimensions of A
......
......@@ -6,7 +6,6 @@
#include "../shared/util.h"
#include "../shared/memcontrol.h"
#include "sym_seq_sum.h"
#include "sum_tsr.h"
#include "../symmetry/sym_indices.h"
#include "../symmetry/symmetrization.h"
#include "../redistribution/nosym_transp.h"
......@@ -402,16 +401,17 @@ namespace CTF_int {
*new_ordering_B = ordering_B;
}
tsum * summation::construct_sum(int inner_stride){
int nvirt, i, iA, iB, order_tot, is_top, sA, sB, need_rep, i_A, i_B, j, k;
tspsum * summation::construct_sparse_sum(int const * phys_mapped){
int nvirt, i, iA, iB, order_tot, is_top, need_rep;
int64_t blk_sz_A, blk_sz_B, vrt_sz_A, vrt_sz_B;
int nphys_dim;
int * idx_arr, * virt_dim, * phys_mapped;
int * virt_dim;
int * idx_arr;
int * virt_blk_len_A, * virt_blk_len_B;
int * blk_len_A, * blk_len_B;
tsum * htsum = NULL , ** rec_tsum = NULL;
mapping * map;
strp_tsr * str_A, * str_B;
tspsum * htsum = NULL , ** rec_tsum = NULL;
is_top = 1;
inv_idx(A->order, idx_A,
......@@ -420,14 +420,11 @@ namespace CTF_int {
nphys_dim = A->topo->order;
CTF_int::alloc_ptr(sizeof(int)*order_tot, (void**)&virt_dim);
CTF_int::alloc_ptr(sizeof(int)*A->order, (void**)&blk_len_A);
CTF_int::alloc_ptr(sizeof(int)*B->order, (void**)&blk_len_B);
CTF_int::alloc_ptr(sizeof(int)*A->order, (void**)&virt_blk_len_A);
CTF_int::alloc_ptr(sizeof(int)*B->order, (void**)&virt_blk_len_B);
CTF_int::alloc_ptr(sizeof(int)*order_tot, (void**)&virt_dim);
CTF_int::alloc_ptr(sizeof(int)*nphys_dim*2, (void**)&phys_mapped);
memset(phys_mapped, 0, sizeof(int)*nphys_dim*2);
/* Determine the block dimensions of each local subtensor */
blk_sz_A = A->size;
......@@ -437,29 +434,6 @@ namespace CTF_int {
calc_dim(B->order, blk_sz_B, B->pad_edge_len, B->edge_map,
&vrt_sz_B, virt_blk_len_B, blk_len_B);
/* Strip out the relevant part of the tensor if we are contracting over diagonal */
sA = strip_diag(A->order, order_tot, idx_A, vrt_sz_A,
A->edge_map, A->topo, A->sr,
blk_len_A, &blk_sz_A, &str_A);
sB = strip_diag(B->order, order_tot, idx_B, vrt_sz_B,
B->edge_map, B->topo, B->sr,
blk_len_B, &blk_sz_B, &str_B);
if (sA || sB){
if (A->wrld->cdt.rank == 0)
DPRINTF(1,"Stripping tensor\n");
strp_sum * ssum = new strp_sum(this);
ssum->sr_A = A->sr;
ssum->sr_B = B->sr;
htsum = ssum;
is_top = 0;
rec_tsum = &ssum->rec_tsum;
ssum->rec_strp_A = str_A;
ssum->rec_strp_B = str_B;
ssum->strip_A = sA;
ssum->strip_B = sB;
}
nvirt = 1;
for (i=0; i<order_tot; i++){
iA = idx_arr[2*i];
......@@ -469,7 +443,6 @@ namespace CTF_int {
while (map->has_child) map = map->child;
if (map->type == VIRTUAL_MAP){
virt_dim[i] = map->np;
if (sA) virt_dim[i] = virt_dim[i]/str_A->strip_dim[iA];
}
else virt_dim[i] = 1;
} else {
......@@ -478,79 +451,43 @@ namespace CTF_int {
while (map->has_child) map = map->child;
if (map->type == VIRTUAL_MAP){
virt_dim[i] = map->np;
if (sB) virt_dim[i] = virt_dim[i]/str_B->strip_dim[iA];
}
else virt_dim[i] = 1;
}
nvirt *= virt_dim[i];
}
for (i=0; i<A->order; i++){
map = &A->edge_map[i];
if (map->type == PHYSICAL_MAP){
phys_mapped[2*map->cdt+0] = 1;
}
while (map->has_child) {
map = map->child;
if (map->type == PHYSICAL_MAP){
phys_mapped[2*map->cdt+0] = 1;
}
}
}
for (i=0; i<B->order; i++){
map = &B->edge_map[i];
if (map->type == PHYSICAL_MAP){
phys_mapped[2*map->cdt+1] = 1;
}
while (map->has_child) {
map = map->child;
if (map->type == PHYSICAL_MAP){
phys_mapped[2*map->cdt+1] = 1;
}
}
}
bool need_perm = false;
if (A->is_sparse || B->is_sparse){
need_perm = true;
/* for (int i=0; i<A->order; i++){
if (idx_arr[2*idx_A[i]+1] != i) need_perm = true;
}
for (int i=0; i<B->order; i++){
if (idx_arr[2*idx_B[i]+1] != i) need_perm = true;
}*/
}
if (need_perm){
if (A->is_sparse){
tsum_sp_pin_keys * sksum = new tsum_sp_pin_keys(this, 1);
if (is_top){
htsum = sksum;
is_top = 0;
} else {
*rec_tsum = sksum;
}
rec_tsum = &sksum->rec_tsum;
tsum_sp_permute * pmsum = new tsum_sp_permute(this, 1, virt_blk_len_A);
*rec_tsum = pmsum;
rec_tsum = &pmsum->rec_tsum;
if (A->is_sparse){
tspsum_pin_keys * sksum = new tspsum_pin_keys(this, 1);
if (is_top){
htsum = sksum;
is_top = 0;
} else {
*rec_tsum = sksum;
}
if (B->is_sparse){
tsum_sp_pin_keys * sksum = new tsum_sp_pin_keys(this, 0);
if (is_top){
htsum = sksum;
is_top = 0;
} else {
*rec_tsum = sksum;
}
rec_tsum = &sksum->rec_tsum;
rec_tsum = &sksum->rec_tsum;
tsum_sp_permute * pmsum = new tsum_sp_permute(this, 0, virt_blk_len_B);
*rec_tsum = pmsum;
rec_tsum = &pmsum->rec_tsum;
tspsum_permute * pmsum = new tspsum_permute(this, 1, virt_blk_len_A);
*rec_tsum = pmsum;
rec_tsum = &pmsum->rec_tsum;
}
if (B->is_sparse){
tspsum_pin_keys * sksum = new tspsum_pin_keys(this, 0);
if (is_top){
htsum = sksum;
is_top = 0;
} else {
*rec_tsum = sksum;
}
rec_tsum = &sksum->rec_tsum;
tspsum_permute * pmsum = new tspsum_permute(this, 0, virt_blk_len_B);
*rec_tsum = pmsum;
rec_tsum = &pmsum->rec_tsum;
}
/* bool need_sp_map = false;
......@@ -565,7 +502,7 @@ namespace CTF_int {
}
if (need_sp_map){
tsum_sp_map * smsum = new tsum_sp_map(this);
tspsum_map * smsum = new tspsum_map(this);
if (is_top){
htsum = smsum;
is_top = 0;
......@@ -584,17 +521,13 @@ namespace CTF_int {
break;
}
}
if (need_rep){
if (A->wrld->cdt.rank == 0)
DPRINTF(1,"Replicating tensor\n");
tsum_replicate * rtsum = new tsum_replicate(this);
/* rtsum->sr_A = A->sr;
rtsum->sr_B = B->sr;
rtsum->is_sparse_A = A->is_sparse;
rtsum->is_sparse_B = B->is_sparse;
rtsum->nnz_A = A->nnz_loc;
rtsum->nnz_B = B->nnz_loc;*/
/* if (A->wrld->cdt.rank == 0)
DPRINTF(1,"Replicating tensor\n");*/
tspsum_replicate * rtsum = new tspsum_replicate(this, phys_mapped, blk_sz_A, blk_sz_B);
if (is_top){
htsum = rtsum;
is_top = 0;
......@@ -602,59 +535,162 @@ namespace CTF_int {
*rec_tsum = rtsum;
}
rec_tsum = &rtsum->rec_tsum;
rtsum->ncdt_A = 0;
rtsum->ncdt_B = 0;
rtsum->size_A = blk_sz_A;
rtsum->size_B = blk_sz_B;
rtsum->cdt_A = NULL;
rtsum->cdt_B = NULL;
for (i=0; i<nphys_dim; i++){
if (phys_mapped[2*i+0] == 0 && phys_mapped[2*i+1] == 1){
rtsum->ncdt_A++;
}
if (phys_mapped[2*i+1] == 0 && phys_mapped[2*i+0] == 1){
rtsum->ncdt_B++;
}
}
if (rtsum->ncdt_A > 0)
CTF_int::alloc_ptr(sizeof(CommData*)*rtsum->ncdt_A, (void**)&rtsum->cdt_A);
if (rtsum->ncdt_B > 0)
CTF_int::alloc_ptr(sizeof(CommData*)*rtsum->ncdt_B, (void**)&rtsum->cdt_B);
rtsum->ncdt_A = 0;
rtsum->ncdt_B = 0;
for (i=0; i<nphys_dim; i++){
if (phys_mapped[2*i+0] == 0 && phys_mapped[2*i+1] == 1){
rtsum->cdt_A[rtsum->ncdt_A] = &A->topo->dim_comm[i];
/* if (rtsum->cdt_A[rtsum->ncdt_A].alive == 0)
rtsum->cdt_A[rtsum->ncdt_A].activate(A->wrld->comm);*/
rtsum->ncdt_A++;
}
if (phys_mapped[2*i+1] == 0 && phys_mapped[2*i+0] == 1){
rtsum->cdt_B[rtsum->ncdt_B] = &B->topo->dim_comm[i];
/* if (rtsum->cdt_B[rtsum->ncdt_B].alive == 0)
rtsum->cdt_B[rtsum->ncdt_B].activate(B->wrld->comm);*/
rtsum->ncdt_B++;
}
}
ASSERT(rtsum->ncdt_A == 0 || rtsum->ncdt_B == 0);
}
/* Sum over virtual sub-blocks */
tspsum_virt * tsumv = new tspsum_virt(this);
if (is_top) {
htsum = tsumv;
is_top = 0;
} else {
*rec_tsum = tsumv;
}
rec_tsum = &tsumv->rec_tsum;
tsumv->num_dim = order_tot;
tsumv->virt_dim = virt_dim;
tsumv->blk_sz_A = vrt_sz_A;
tsumv->blk_sz_B = vrt_sz_B;
int * new_sym_A, * new_sym_B;
CTF_int::alloc_ptr(sizeof(int)*A->order, (void**)&new_sym_A);
memcpy(new_sym_A, A->sym, sizeof(int)*A->order);
CTF_int::alloc_ptr(sizeof(int)*B->order, (void**)&new_sym_B);
memcpy(new_sym_B, B->sym, sizeof(int)*B->order);
seq_tsr_spsum * tsumseq = new seq_tsr_spsum(this);
tsumseq->is_inner = 0;
tsumseq->edge_len_A = virt_blk_len_A;
tsumseq->sym_A = new_sym_A;
tsumseq->edge_len_B = virt_blk_len_B;
tsumseq->sym_B = new_sym_B;
tsumseq->is_custom = is_custom;
if (is_custom){
tsumseq->is_inner = 0;
tsumseq->func = func;
} else tsumseq->func = NULL;
if (is_top) {
htsum = tsumseq;
is_top = 0;
} else {
*rec_tsum = tsumseq;
}
CTF_int::cdealloc(idx_arr);
CTF_int::cdealloc(blk_len_A);
CTF_int::cdealloc(blk_len_B);
return htsum;
}
tsum * summation::construct_dense_sum(int inner_stride,
int const * phys_mapped){
int i, iA, iB, order_tot, is_top, sA, sB, need_rep, i_A, i_B, j, k;
int64_t blk_sz_A, blk_sz_B, vrt_sz_A, vrt_sz_B;
int nphys_dim, nvirt;
int * idx_arr, * virt_dim;
int * virt_blk_len_A, * virt_blk_len_B;
int * blk_len_A, * blk_len_B;
tsum * htsum = NULL , ** rec_tsum = NULL;
mapping * map;
strp_tsr * str_A, * str_B;
is_top = 1;
inv_idx(A->order, idx_A,
B->order, idx_B,
&order_tot, &idx_arr);
nphys_dim = A->topo->order;
CTF_int::alloc_ptr(sizeof(int)*order_tot, (void**)&virt_dim);
CTF_int::alloc_ptr(sizeof(int)*A->order, (void**)&blk_len_A);
CTF_int::alloc_ptr(sizeof(int)*B->order, (void**)&blk_len_B);
CTF_int::alloc_ptr(sizeof(int)*A->order, (void**)&virt_blk_len_A);
CTF_int::alloc_ptr(sizeof(int)*B->order, (void**)&virt_blk_len_B);
/* Determine the block dimensions of each local subtensor */
blk_sz_A = A->size;
blk_sz_B = B->size;
calc_dim(A->order, blk_sz_A, A->pad_edge_len, A->edge_map,
&vrt_sz_A, virt_blk_len_A, blk_len_A);
calc_dim(B->order, blk_sz_B, B->pad_edge_len, B->edge_map,
&vrt_sz_B, virt_blk_len_B, blk_len_B);
/* Strip out the relevant part of the tensor if we are contracting over diagonal */
sA = strip_diag(A->order, order_tot, idx_A, vrt_sz_A,
A->edge_map, A->topo, A->sr,
blk_len_A, &blk_sz_A, &str_A);
sB = strip_diag(B->order, order_tot, idx_B, vrt_sz_B,
B->edge_map, B->topo, B->sr,
blk_len_B, &blk_sz_B, &str_B);
if (sA || sB){
if (A->wrld->cdt.rank == 0)
DPRINTF(1,"Stripping tensor\n");
strp_sum * ssum = new strp_sum(this);
ssum->sr_A = A->sr;
ssum->sr_B = B->sr;
htsum = ssum;
is_top = 0;
rec_tsum = &ssum->rec_tsum;
ssum->rec_strp_A = str_A;
ssum->rec_strp_B = str_B;
ssum->strip_A = sA;
ssum->strip_B = sB;
}
nvirt = 1;
for (i=0; i<order_tot; i++){
iA = idx_arr[2*i];
iB = idx_arr[2*i+1];
if (iA != -1){
map = &A->edge_map[iA];
while (map->has_child) map = map->child;
if (map->type == VIRTUAL_MAP){
virt_dim[i] = map->np;
if (sA) virt_dim[i] = virt_dim[i]/str_A->strip_dim[iA];
}
else virt_dim[i] = 1;
} else {
ASSERT(iB!=-1);
map = &B->edge_map[iB];
while (map->has_child) map = map->child;
if (map->type == VIRTUAL_MAP){
virt_dim[i] = map->np;
if (sB) virt_dim[i] = virt_dim[i]/str_B->strip_dim[iB];
}
else virt_dim[i] = 1;
}
nvirt *= virt_dim[i];
}
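/* Example (illustrative numbers): with virt_dim = {2,2,1} over three
   total indices, the loop above yields nvirt = 4, so the tsum_virt
   object constructed below iterates over 4 virtual block pairs. */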
need_rep = 0;
for (i=0; i<nphys_dim; i++){
if (phys_mapped[2*i+0] == 0 ||
phys_mapped[2*i+1] == 0){
need_rep = 1;
break;
}
}
if (need_rep){
/* if (A->wrld->cdt.rank == 0)
DPRINTF(1,"Replicating tensor\n");*/
tsum_replicate * rtsum = new tsum_replicate(this, phys_mapped, blk_sz_A, blk_sz_B);
if (is_top){
htsum = rtsum;
is_top = 0;
} else {
*rec_tsum = rtsum;
}
rec_tsum = &rtsum->rec_tsum;
}
/* Sum over virtual sub-blocks */
if (nvirt > 1 || A->is_sparse || B->is_sparse){
if (nvirt > 1){
tsum_virt * tsumv = new tsum_virt(this);
/* tsumv->sr_A = A->sr;
tsumv->sr_B = B->sr;
tsumv->is_sparse_A = A->is_sparse;
tsumv->is_sparse_B = B->is_sparse;
tsumv->nnz_A = A->nnz_loc;
tsumv->nnz_B = B->nnz_loc;
tsumv->nnz_blk_B = B->nnz_blk;*/
if (is_top) {
htsum = tsumv;
is_top = 0;
......@@ -665,22 +701,16 @@ namespace CTF_int {
tsumv->num_dim = order_tot;
tsumv->virt_dim = virt_dim;
// tsumv->order_A = A->order;
tsumv->blk_sz_A = vrt_sz_A;
// tsumv->idx_map_A = idx_A;
// tsumv->order_B = B->order;
tsumv->blk_sz_B = vrt_sz_B;
// tsumv->idx_map_B = idx_B;
tsumv->buffer = NULL;
} else CTF_int::cdealloc(virt_dim);
int * new_sym_A, * new_sym_B;
CTF_int::alloc_ptr(sizeof(int)*A->order, (void**)&new_sym_A);
memcpy(new_sym_A, A->sym, sizeof(int)*A->order);
CTF_int::alloc_ptr(sizeof(int)*B->order, (void**)&new_sym_B);
memcpy(new_sym_B, B->sym, sizeof(int)*B->order);
seq_tsr_sum * tsumseq = new seq_tsr_sum(this);
/* tsumseq->sr_A = A->sr;
tsumseq->sr_B = B->sr;
tsumseq->is_sparse_A = A->is_sparse;
tsumseq->is_sparse_B = B->is_sparse;
tsumseq->nnz_A = A->nnz_loc;
tsumseq->nnz_B = B->nnz_loc;*/
if (inner_stride == -1){
tsumseq->is_inner = 0;
} else {
......@@ -728,18 +758,8 @@ namespace CTF_int {
}
}
}
if (is_top) {
htsum = tsumseq;
is_top = 0;
} else {
*rec_tsum = tsumseq;
}
// tsumseq->order_A = A->order;
// tsumseq->idx_map_A = idx_A;
tsumseq->edge_len_A = virt_blk_len_A;
tsumseq->sym_A = new_sym_A;
// tsumseq->order_B = B->order;
// tsumseq->idx_map_B = idx_B;
tsumseq->edge_len_B = virt_blk_len_B;
tsumseq->sym_B = new_sym_B;
tsumseq->is_custom = is_custom;
......@@ -747,15 +767,70 @@ namespace CTF_int {
tsumseq->is_inner = 0;
tsumseq->func = func;
} else tsumseq->func = NULL;
// htsum->alpha = alpha;
// htsum->beta = beta;
// htsum->A = A->data;
// htsum->B = B->data;
if (is_top) {
htsum = tsumseq;
is_top = 0;
} else {
*rec_tsum = tsumseq;
}
CTF_int::cdealloc(idx_arr);
CTF_int::cdealloc(blk_len_A);
CTF_int::cdealloc(blk_len_B);
return htsum;
}
tsum * summation::construct_sum(int inner_stride){
int i;
int nphys_dim;
int * phys_mapped;
tsum * htsum = NULL;
mapping * map;
nphys_dim = A->topo->order;
CTF_int::alloc_ptr(sizeof(int)*nphys_dim*2, (void**)&phys_mapped);
memset(phys_mapped, 0, sizeof(int)*nphys_dim*2);
for (i=0; i<A->order; i++){
map = &A->edge_map[i];
if (map->type == PHYSICAL_MAP){
phys_mapped[2*map->cdt+0] = 1;
}
while (map->has_child) {
map = map->child;
if (map->type == PHYSICAL_MAP){
phys_mapped[2*map->cdt+0] = 1;
}
}
}
for (i=0; i<B->order; i++){
map = &B->edge_map[i];
if (map->type == PHYSICAL_MAP){
phys_mapped[2*map->cdt+1] = 1;
}
while (map->has_child) {
map = map->child;
if (map->type == PHYSICAL_MAP){
phys_mapped[2*map->cdt+1] = 1;
}
}
}
if (A->is_sparse || B->is_sparse){
htsum = construct_sparse_sum(phys_mapped);
} else {
htsum = construct_dense_sum(inner_stride, phys_mapped);
}
CTF_int::cdealloc(phys_mapped);
return htsum;
......@@ -1432,17 +1507,18 @@ namespace CTF_int {
MPI_Barrier(tnsr_B->wrld->comm);
sumf->run();
if (tnsr_B->is_sparse){
if (tnsr_B->data != sumf->new_B){
tspsum * spsumf = (tspsum*)sumf;
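// This downcast is safe by construction: when B is sparse, construct_sum
// takes the sparse branch, so sumf is the tspsum built by construct_sparse_sum.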
if (tnsr_B->data != spsumf->new_B){
cdealloc(tnsr_B->data);
tnsr_B->data = sumf->new_B;
//tnsr_B->nnz_loc = sumf->new_nnz_B;
tnsr_B->data = spsumf->new_B;
//tnsr_B->nnz_loc = spsumf->new_nnz_B;
tnsr_B->nnz_loc = 0;
for (int i=0; i<tnsr_B->calc_nvirt(); i++){
// printf("rec %p pin %p new_blk_nnz_B[%d] = %ld\n",sumf->nnz_blk_B,tnsr_B->nnz_blk,i,tnsr_B->nnz_blk[i]);
// printf("rec %p pin %p new_blk_nnz_B[%d] = %ld\n",spsumf->nnz_blk_B,tnsr_B->nnz_blk,i,tnsr_B->nnz_blk[i]);
tnsr_B->nnz_loc += tnsr_B->nnz_blk[i];
}
}
ASSERT(tnsr_B->nnz_loc == sumf->new_nnz_B);
ASSERT(tnsr_B->nnz_loc == spsumf->new_nnz_B);
}
/*tnsr_B->unfold();
tnsr_B->print();
......
......@@ -3,6 +3,7 @@
#include "assert.h"
#include "sum_tsr.h"
#include "spsum_tsr.h"
namespace CTF_int {
class tensor;
......@@ -174,11 +175,29 @@ namespace CTF_int {
/**
* \brief constructs a summation object to compute B = B*beta+alpha*A
* \param[in] inner_stride local daxpy stride
* \return tsum summation class pointer to run
*/
tsum * construct_sum(int inner_stride=-1);
/**
* \brief constructs a summation object for tensors A and B, at least one of which is sparse,
*        computing B = B*beta+alpha*A
* \param[in] phys_mapped dimension 2*num_phys_dim, keeps track of which dimensions A and B are mapped to
* \return tspsum summation class pointer to run
*/
tspsum * construct_sparse_sum(int const * phys_mapped);
/**
* \brief constructs a summation object for dense tensors A and B,
*        computing B = B*beta+alpha*A
* \param[in] inner_stride local daxpy stride
* \param[in] phys_mapped dimension 2*num_phys_dim, keeps track of which dimensions A and B are mapped to
* \return tsum summation class pointer to run
*/
tsum * construct_dense_sum(int inner_stride,
int const * phys_mapped);
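For reference, a hedged sketch of the caller's view of this split (mirroring the execute path in the summation.cxx diff above; ownership and cleanup elided):

tsum * sumf = construct_sum(inner_stride); // dispatches on A/B sparsity
sumf->run();                               // executes the recursive chain
// when B is sparse, the caller then adopts ((tspsum*)sumf)->new_B as B's data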
/**
* \brief a*idx_map_A(A) + b*idx_map_B(B) -> idx_map_B(B).
* performs all necessary symmetric permutations removes/returns A/B to home buffer
......
......@@ -504,7 +504,7 @@ namespace CTF_int {
this->set_padding();
if (!is_sparse && this->size > INT_MAX && wrld->rank == 0)
printf("CTF WARNING: Tensor %s is has local size %ld, which is greater than INT_MAX=%ld, so MPI could run into problems\n", name, size, INT_MAX);
printf("CTF WARNING: Tensor %s is has local size %ld, which is greater than INT_MAX=%d, so MPI could run into problems\n", name, size, INT_MAX);
if (is_sparse){
nnz_blk = (int64_t*)alloc(sizeof(int64_t)*calc_nvirt());
......@@ -1798,7 +1798,7 @@ namespace CTF_int {
}
if (size > INT_MAX && wrld->cdt.rank == 0)
printf("CTF WARNING: Tensor %s is being redistributed to a mapping where its size is %ld, which is greater than INT_MAX=%ld, so MPI could run into problems\n", name, size, INT_MAX);
printf("CTF WARNING: Tensor %s is being redistributed to a mapping where its size is %ld, which is greater than INT_MAX=%d, so MPI could run into problems\n", name, size, INT_MAX);
#ifdef HOME_CONTRACT
if (this->is_home){
......@@ -1997,10 +1997,10 @@ namespace CTF_int {
}
char * pwdata = (char*)alloc(sr->pair_size()*nw);
PairIterator wdata(sr, pwdata);
nw=0;
#ifdef USE_OMP
#pragma omp parallel for
// #pragma omp parallel for
#endif
nw=0;
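// Assumption (not stated in the diff): the loop below appends matched pairs
// using nw as a running counter, which is why the parallel-for above stays
// disabled and nw is reset here, after the #endif.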
for (int p=0; p<nnz_loc; p++){
int64_t k = pi[p].k();
if ((k/lda_i)%lens[i] == (k/lda_j)%lens[j]){
......