Compressed Sparse Blocks  1.2
Semirings.h
Go to the documentation of this file.
1
2 #ifndef _SEMIRINGS_H_
3 #define _SEMIRINGS_H_
4
5 #include <utility>
6 #include <climits>
7 #include <cmath>
8 #include <tr1/array>
9 #include "promote.h"
10
// Addition functor in which std::numeric_limits<T>::max() plays the role of
// "infinity": if either operand equals that value the result is that value,
// otherwise the result is the ordinary sum a + b.
template <typename T>
struct inf_plus{
    T operator()(const T& a, const T& b) const {
        const T kInfinity = std::numeric_limits<T>::max();
        const bool eitherInfinite = (a == kInfinity) || (b == kInfinity);
        // Infinity absorbs: the sum is only formed when both operands are finite.
        return eitherInfinite ? kInfinity : a + b;
    }
};
21
22 // (+,*) on scalars
23 template <class T1, class T2>
24 struct PTSR
25 {
27
28  static T_promote add(const T1 & arg1, const T2 & arg2)
29  {
30  return (static_cast<T_promote>(arg1) +
31  static_cast<T_promote>(arg2) );
32  }
33  static T_promote multiply(const T1 & arg1, const T2 & arg2)
34  {
35  return (static_cast<T_promote>(arg1) *
36  static_cast<T_promote>(arg2) );
37  }
38  // y += ax overload with a=1
39  static void axpy(const T2 & x, T_promote & y)
40  {
41  y += x;
42  }
43
44  static void axpy(T1 a, const T2 & x, T_promote & y)
45  {
46  y += a*x;
47  }
48 };
49
50
// Compile-time loop unroller: UnrollerL<Begin, End, Step>::step(func) calls
// func(Begin), func(Begin+Step), ... and stops once the index equals End
// (End itself is not visited).
//
// The recursion terminates only through the Begin == End specialization
// below, so (End - Begin) must be an exact, non-negative multiple of Step;
// otherwise template instantiation never reaches the base case.
template<int Begin, int End, int Step>
struct UnrollerL {
    template<typename Lambda>
    static void step(Lambda& func) {
        func(Begin);
        // Advance to the next index; each level is a distinct instantiation,
        // so the whole sequence is expanded at compile time.
        UnrollerL<Begin+Step, End, Step>::step(func);
    }
};

// Terminating specialization, selected when Begin has reached End.
template<int End, int Step>
struct UnrollerL<End, End, Step> {
    template<typename Lambda>
    static void step(Lambda& func) {
        // base case is when Begin=End; do nothing
    }
};
67
68
69 // (+,*) on std:array's
70 template<class T1, class T2, unsigned D>
71 struct PTSRArray
72 {
74
75  // y <- a*x + y overload with a=1
76  static void axpy(const array<T2, D> & b, array<T_promote, D> & c)
77  {
78  const T2 * __restrict barr = b.data();
79  T_promote * __restrict carr = c.data();
80  __assume_aligned(barr, ALIGN);
81  __assume_aligned(carr, ALIGN);
82
83  #pragma simd
84  for(int i=0; i<D; ++i)
85  {
86  carr[i] += barr[i];
87  }
88  // auto multadd = [&] (int i) { c[i] += b[i]; };
89  // UnrollerL<0, D, 1>::step ( multadd );
90  }
91
92  // Todo: Do partial unrolling; this code will bloat for D > 32
93  static void axpy(T1 a, const array<T2,D> & b, array<T_promote,D> & c)
94  {
95  const T2 * __restrict barr = b.data();
96  T_promote * __restrict carr = c.data();
97  __assume_aligned(barr, ALIGN);
98  __assume_aligned(carr, ALIGN);
99
100  #pragma simd
101  for(int i=0; i<D; ++i)
102  {
103  carr[i] += a* barr[i];
104  }
105  //auto multadd = [&] (int i) { carr[i] += a* barr[i]; };
106  //UnrollerL<0, D, 1>::step ( multadd );
107  }
108 };
109
110 // (min,+) on scalars
111 template <class T1, class T2>
112 struct MPSR
113 {
115
116  static T_promote add(const T1 & arg1, const T2 & arg2)
117  {
118  return std::min<T_promote>
119  (static_cast<T_promote>(arg1), static_cast<T_promote>(arg2));
120  }
121  static T_promote multiply(const T1 & arg1, const T2 & arg2)
122  {
123  return inf_plus< T_promote >
124  (static_cast<T_promote>(arg1), static_cast<T_promote>(arg2));
125  }
126 };
127
128
129 #endif
static void axpy(T1 a, const array< T2, D > &b, array< T_promote, D > &c)
Definition: Semirings.h:93
static void step(Lambda &func)
Definition: Semirings.h:63
promote_trait< T1, T2 >::T_promote T_promote
Definition: Semirings.h:26
static void step(Lambda &func)
Definition: Semirings.h:54
Definition: Semirings.h:24
#define ALIGN
Definition: spmm_test.cpp:26
static void axpy(T1 a, const T2 &x, T_promote &y)
Definition: Semirings.h:44
static void axpy(const array< T2, D > &b, array< T_promote, D > &c)
Definition: Semirings.h:76
T operator()(const T &a, const T &b) const
Definition: Semirings.h:13
static T_promote add(const T1 &arg1, const T2 &arg2)
Definition: Semirings.h:28
promote_trait< T1, T2 >::T_promote T_promote
Definition: Semirings.h:114
promote_trait< T1, T2 >::T_promote T_promote
Definition: Semirings.h:73
static void axpy(const T2 &x, T_promote &y)
Definition: Semirings.h:39
static T_promote multiply(const T1 &arg1, const T2 &arg2)
Definition: Semirings.h:33
static T_promote add(const T1 &arg1, const T2 &arg2)
Definition: Semirings.h:116
static T_promote multiply(const T1 &arg1, const T2 &arg2)
Definition: Semirings.h:121