crypto: scalar_vec_t, scalar_mat_t
This commit is contained in:
parent
5b1fa3d5e8
commit
2fdeefb545
1 changed files with 147 additions and 0 deletions
|
|
@ -717,6 +717,146 @@ namespace crypto
|
|||
}; // struct point_g_t
|
||||
|
||||
|
||||
//
|
||||
// vector of scalars
|
||||
//
|
||||
struct scalar_vec_t : public std::vector<scalar_t>
|
||||
{
|
||||
typedef std::vector<scalar_t> super_t;
|
||||
|
||||
scalar_vec_t() {}
|
||||
scalar_vec_t(size_t n) : super_t(n) {}
|
||||
scalar_vec_t(std::initializer_list<scalar_t> init_list) : super_t(init_list) {}
|
||||
|
||||
bool is_reduced() const
|
||||
{
|
||||
for (auto& el : *this)
|
||||
if (!el.is_reduced())
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
// add a scalar rhs to each element
|
||||
scalar_vec_t operator+(const scalar_t& rhs) const
|
||||
{
|
||||
scalar_vec_t result(size());
|
||||
for (size_t i = 0, n = size(); i < n; ++i)
|
||||
result[i] = at(i) + rhs;
|
||||
return result;
|
||||
}
|
||||
|
||||
// subtract a scalar rhs to each element
|
||||
scalar_vec_t operator-(const scalar_t& rhs) const
|
||||
{
|
||||
scalar_vec_t result(size());
|
||||
for (size_t i = 0, n = size(); i < n; ++i)
|
||||
result[i] = at(i) - rhs;
|
||||
return result;
|
||||
}
|
||||
|
||||
// multiply each element of the vector by a scalar
|
||||
scalar_vec_t operator*(const scalar_t& rhs) const
|
||||
{
|
||||
scalar_vec_t result(size());
|
||||
for (size_t i = 0, n = size(); i < n; ++i)
|
||||
result[i] = at(i) * rhs;
|
||||
return result;
|
||||
}
|
||||
|
||||
// component-wise multiplication (a.k.a the Hadamard product) (only if their sizes match)
|
||||
scalar_vec_t operator*(const scalar_vec_t& rhs) const
|
||||
{
|
||||
scalar_vec_t result;
|
||||
const size_t n = size();
|
||||
if (n != rhs.size())
|
||||
return result;
|
||||
|
||||
result.resize(size());
|
||||
for (size_t i = 0; i < n; ++i)
|
||||
result[i] = at(i) * rhs[i];
|
||||
return result;
|
||||
}
|
||||
|
||||
// add each element of two vectors, but only if their sizes match
|
||||
scalar_vec_t operator+(const scalar_vec_t& rhs) const
|
||||
{
|
||||
scalar_vec_t result;
|
||||
const size_t n = size();
|
||||
if (n != rhs.size())
|
||||
return result;
|
||||
|
||||
result.resize(size());
|
||||
for (size_t i = 0; i < n; ++i)
|
||||
result[i] = at(i) + rhs[i];
|
||||
return result;
|
||||
}
|
||||
|
||||
// zeroes all elements
|
||||
void zero()
|
||||
{
|
||||
size_t size_bytes = sizeof(scalar_t) * size();
|
||||
memset(data(), 0, size_bytes);
|
||||
}
|
||||
|
||||
// invert all elements in-place efficiently: 4*N muptiplications + 1 inversion
|
||||
void invert()
|
||||
{
|
||||
// muls muls_rev
|
||||
// 0: 1 2 3 .. n-1
|
||||
// 1: 0 2 3 .. n-1
|
||||
// 2: 0 1 3 .. n-1
|
||||
//
|
||||
// n-1: 0 1 2 3 .. n-2
|
||||
|
||||
const size_t size = this->size();
|
||||
|
||||
if (size < 2)
|
||||
{
|
||||
if (size == 1)
|
||||
at(0) = at(0).reciprocal();
|
||||
return;
|
||||
}
|
||||
|
||||
scalar_vec_t muls(size), muls_rev(size);
|
||||
muls[0] = 1;
|
||||
for (size_t i = 0; i < size - 1; ++i)
|
||||
muls[i + 1] = at(i) * muls[i];
|
||||
|
||||
muls_rev[size - 1] = 1;
|
||||
for (size_t i = size - 1; i != 0; --i)
|
||||
muls_rev[i - 1] = at(i) * muls_rev[i];
|
||||
|
||||
scalar_t inv = (muls[size - 1] * at(size - 1)).reciprocal();
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
at(i) = muls[i] * inv * muls_rev[i];
|
||||
}
|
||||
|
||||
scalar_t calc_hs() const;
|
||||
|
||||
}; // scalar_vec_t
|
||||
|
||||
|
||||
// treats vector of scalars as an M x N matrix just for convenience
|
||||
template<size_t N>
|
||||
struct scalar_mat_t : public scalar_vec_t
|
||||
{
|
||||
typedef scalar_vec_t super_t;
|
||||
static_assert(N > 0, "invalid N value");
|
||||
|
||||
scalar_mat_t() {}
|
||||
scalar_mat_t(size_t n) : super_t(n) {}
|
||||
scalar_mat_t(std::initializer_list<scalar_t> init_list) : super_t(init_list) {}
|
||||
|
||||
// matrix accessor M rows x N cols
|
||||
scalar_t& operator()(size_t row, size_t col)
|
||||
{
|
||||
return at(row * N + col);
|
||||
}
|
||||
}; // scalar_mat_t
|
||||
|
||||
|
||||
|
||||
//
|
||||
// Global constants
|
||||
//
|
||||
|
|
@ -908,4 +1048,11 @@ namespace crypto
|
|||
}; // hash_helper_t struct
|
||||
|
||||
|
||||
inline scalar_t scalar_vec_t::calc_hs() const
|
||||
{
|
||||
// hs won't touch memory if size is 0, so it's safe
|
||||
return hash_helper_t::hs(data(), sizeof(scalar_t) * size());
|
||||
}
|
||||
|
||||
|
||||
} // namespace crypto
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue