provisional distributed spmv
This commit is contained in:
76
csr_mat.hpp
76
csr_mat.hpp
@@ -6,6 +6,8 @@
|
||||
#include "coo_mat.hpp"
|
||||
#include "algorithm.hpp"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
template <Where where>
|
||||
class CsrMat {
|
||||
public:
|
||||
@@ -161,6 +163,16 @@ public:
|
||||
CsrMat(CsrMat &&other) = delete;
|
||||
CsrMat(const CsrMat &other) = delete;
|
||||
|
||||
CsrMat &operator=(CsrMat &&rhs) {
|
||||
if (this != &rhs) {
|
||||
rowPtr_ = std::move(rhs.rowPtr_);
|
||||
colInd_ = std::move(rhs.colInd_);
|
||||
val_ = std::move(rhs.val_);
|
||||
numCols_ = std::move(rhs.numCols_);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
// create device matrix from host
|
||||
CsrMat(const CsrMat<Where::host> &m) :
|
||||
rowPtr_(m.rowPtr_), colInd_(m.colInd_), val_(m.val_), numCols_(m.numCols_) {
|
||||
@@ -193,4 +205,68 @@ public:
|
||||
return v;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
// mxn random matrix with nnz
|
||||
CsrMat<Where::host> random_matrix(const int64_t m, const int64_t n, const int64_t nnz) {
|
||||
|
||||
if (m * n < nnz) {
|
||||
throw std::logic_error(AT);
|
||||
}
|
||||
|
||||
CooMat coo(m,n);
|
||||
while(coo.nnz() < nnz) {
|
||||
|
||||
int64_t toPush = nnz - coo.nnz();
|
||||
std::cerr << "adding " << toPush << " non-zeros\n";
|
||||
for (int64_t _ = 0; _ < toPush; ++_) {
|
||||
int r = rand() % m;
|
||||
int c = rand() % n;
|
||||
float e = 1.0;
|
||||
coo.push_back(r, c, e);
|
||||
}
|
||||
std::cerr << "removing duplicate non-zeros\n";
|
||||
coo.remove_duplicates();
|
||||
}
|
||||
coo.sort();
|
||||
std::cerr << "coo: " << coo.num_rows() << "x" << coo.num_cols() << "\n";
|
||||
CsrMat<Where::host> csr(coo);
|
||||
std::cerr << "csr: " << csr.num_rows() << "x" << csr.num_cols() << " w/ " << csr.nnz() << "\n";
|
||||
return csr;
|
||||
};
|
||||
|
||||
// nxn diagonal matrix with bandwidth b
|
||||
CsrMat<Where::host> random_band_matrix(const int64_t n, const int64_t bw, const int64_t nnz) {
|
||||
|
||||
CooMat coo(n,n);
|
||||
while(coo.nnz() < nnz) {
|
||||
|
||||
int64_t toPush = nnz - coo.nnz();
|
||||
std::cerr << "adding " << toPush << " non-zeros\n";
|
||||
for (int64_t _ = 0; _ < toPush; ++_) {
|
||||
int r = rand() % n; // random row
|
||||
|
||||
// column in the band
|
||||
int lb = r - bw;
|
||||
int ub = r + bw + 1;
|
||||
int64_t c = rand() % (ub - lb) + lb;
|
||||
if (c < 0 || c >= n) {
|
||||
// retry, don't over-weight first or last column
|
||||
continue;
|
||||
}
|
||||
float e = 1.0;
|
||||
|
||||
assert(c < n);
|
||||
assert(r < n);
|
||||
coo.push_back(r, c, e);
|
||||
}
|
||||
std::cerr << "removing duplicate non-zeros\n";
|
||||
coo.remove_duplicates();
|
||||
}
|
||||
coo.sort();
|
||||
std::cerr << "coo: " << coo.num_rows() << "x" << coo.num_cols() << "\n";
|
||||
CsrMat<Where::host> csr(coo);
|
||||
std::cerr << "csr: " << csr.num_rows() << "x" << csr.num_cols() << " w/ " << csr.nnz() << "\n";
|
||||
return csr;
|
||||
};
|
Reference in New Issue
Block a user