crypto: scalar_vec_t, scalar_mat_t

This commit is contained in:
sowle 2021-06-05 03:36:58 +03:00
parent 5b1fa3d5e8
commit 2fdeefb545
No known key found for this signature in database
GPG key ID: C07A24B2D89D49FC

View file

@ -717,6 +717,146 @@ namespace crypto
}; // struct point_g_t
//
// vector of scalars
//
struct scalar_vec_t : public std::vector<scalar_t>
{
typedef std::vector<scalar_t> super_t;
scalar_vec_t() {}
scalar_vec_t(size_t n) : super_t(n) {}
scalar_vec_t(std::initializer_list<scalar_t> init_list) : super_t(init_list) {}
bool is_reduced() const
{
for (auto& el : *this)
if (!el.is_reduced())
return false;
return true;
}
// add a scalar rhs to each element
scalar_vec_t operator+(const scalar_t& rhs) const
{
scalar_vec_t result(size());
for (size_t i = 0, n = size(); i < n; ++i)
result[i] = at(i) + rhs;
return result;
}
// subtract a scalar rhs to each element
scalar_vec_t operator-(const scalar_t& rhs) const
{
scalar_vec_t result(size());
for (size_t i = 0, n = size(); i < n; ++i)
result[i] = at(i) - rhs;
return result;
}
// multiply each element of the vector by a scalar
scalar_vec_t operator*(const scalar_t& rhs) const
{
scalar_vec_t result(size());
for (size_t i = 0, n = size(); i < n; ++i)
result[i] = at(i) * rhs;
return result;
}
// component-wise multiplication (a.k.a the Hadamard product) (only if their sizes match)
scalar_vec_t operator*(const scalar_vec_t& rhs) const
{
scalar_vec_t result;
const size_t n = size();
if (n != rhs.size())
return result;
result.resize(size());
for (size_t i = 0; i < n; ++i)
result[i] = at(i) * rhs[i];
return result;
}
// add each element of two vectors, but only if their sizes match
scalar_vec_t operator+(const scalar_vec_t& rhs) const
{
scalar_vec_t result;
const size_t n = size();
if (n != rhs.size())
return result;
result.resize(size());
for (size_t i = 0; i < n; ++i)
result[i] = at(i) + rhs[i];
return result;
}
// zeroes all elements
void zero()
{
size_t size_bytes = sizeof(scalar_t) * size();
memset(data(), 0, size_bytes);
}
// invert all elements in-place efficiently: 4*N muptiplications + 1 inversion
void invert()
{
// muls muls_rev
// 0: 1 2 3 .. n-1
// 1: 0 2 3 .. n-1
// 2: 0 1 3 .. n-1
//
// n-1: 0 1 2 3 .. n-2
const size_t size = this->size();
if (size < 2)
{
if (size == 1)
at(0) = at(0).reciprocal();
return;
}
scalar_vec_t muls(size), muls_rev(size);
muls[0] = 1;
for (size_t i = 0; i < size - 1; ++i)
muls[i + 1] = at(i) * muls[i];
muls_rev[size - 1] = 1;
for (size_t i = size - 1; i != 0; --i)
muls_rev[i - 1] = at(i) * muls_rev[i];
scalar_t inv = (muls[size - 1] * at(size - 1)).reciprocal();
for (size_t i = 0; i < size; ++i)
at(i) = muls[i] * inv * muls_rev[i];
}
scalar_t calc_hs() const;
}; // scalar_vec_t
// treats vector of scalars as an M x N matrix just for convenience
template<size_t N>
struct scalar_mat_t : public scalar_vec_t
{
typedef scalar_vec_t super_t;
static_assert(N > 0, "invalid N value");
scalar_mat_t() {}
scalar_mat_t(size_t n) : super_t(n) {}
scalar_mat_t(std::initializer_list<scalar_t> init_list) : super_t(init_list) {}
// matrix accessor M rows x N cols
scalar_t& operator()(size_t row, size_t col)
{
return at(row * N + col);
}
}; // scalar_mat_t
//
// Global constants
//
@ -908,4 +1048,11 @@ namespace crypto
}; // hash_helper_t struct
inline scalar_t scalar_vec_t::calc_hs() const
{
// hs won't touch memory if size is 0, so it's safe
return hash_helper_t::hs(data(), sizeof(scalar_t) * size());
}
} // namespace crypto