From 2fdeefb545f034b290b366b20c8335389d3b270b Mon Sep 17 00:00:00 2001 From: sowle Date: Sat, 5 Jun 2021 03:36:58 +0300 Subject: [PATCH] crypto: scalar_vec_t, scalar_mat_t --- src/crypto/crypto-sugar.h | 147 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 147 insertions(+) diff --git a/src/crypto/crypto-sugar.h b/src/crypto/crypto-sugar.h index 30abccb9..1c1b1b9a 100644 --- a/src/crypto/crypto-sugar.h +++ b/src/crypto/crypto-sugar.h @@ -717,6 +717,146 @@ namespace crypto }; // struct point_g_t + // + // vector of scalars + // + struct scalar_vec_t : public std::vector + { + typedef std::vector super_t; + + scalar_vec_t() {} + scalar_vec_t(size_t n) : super_t(n) {} + scalar_vec_t(std::initializer_list init_list) : super_t(init_list) {} + + bool is_reduced() const + { + for (auto& el : *this) + if (!el.is_reduced()) + return false; + return true; + } + + // add a scalar rhs to each element + scalar_vec_t operator+(const scalar_t& rhs) const + { + scalar_vec_t result(size()); + for (size_t i = 0, n = size(); i < n; ++i) + result[i] = at(i) + rhs; + return result; + } + + // subtract a scalar rhs to each element + scalar_vec_t operator-(const scalar_t& rhs) const + { + scalar_vec_t result(size()); + for (size_t i = 0, n = size(); i < n; ++i) + result[i] = at(i) - rhs; + return result; + } + + // multiply each element of the vector by a scalar + scalar_vec_t operator*(const scalar_t& rhs) const + { + scalar_vec_t result(size()); + for (size_t i = 0, n = size(); i < n; ++i) + result[i] = at(i) * rhs; + return result; + } + + // component-wise multiplication (a.k.a the Hadamard product) (only if their sizes match) + scalar_vec_t operator*(const scalar_vec_t& rhs) const + { + scalar_vec_t result; + const size_t n = size(); + if (n != rhs.size()) + return result; + + result.resize(size()); + for (size_t i = 0; i < n; ++i) + result[i] = at(i) * rhs[i]; + return result; + } + + // add each element of two vectors, but only if their sizes match + scalar_vec_t operator+(const scalar_vec_t& rhs) const + { + scalar_vec_t result; + const size_t n = size(); + if (n != rhs.size()) + return result; + + result.resize(size()); + for (size_t i = 0; i < n; ++i) + result[i] = at(i) + rhs[i]; + return result; + } + + // zeroes all elements + void zero() + { + size_t size_bytes = sizeof(scalar_t) * size(); + memset(data(), 0, size_bytes); + } + + // invert all elements in-place efficiently: 4*N muptiplications + 1 inversion + void invert() + { + // muls muls_rev + // 0: 1 2 3 .. n-1 + // 1: 0 2 3 .. n-1 + // 2: 0 1 3 .. n-1 + // + // n-1: 0 1 2 3 .. n-2 + + const size_t size = this->size(); + + if (size < 2) + { + if (size == 1) + at(0) = at(0).reciprocal(); + return; + } + + scalar_vec_t muls(size), muls_rev(size); + muls[0] = 1; + for (size_t i = 0; i < size - 1; ++i) + muls[i + 1] = at(i) * muls[i]; + + muls_rev[size - 1] = 1; + for (size_t i = size - 1; i != 0; --i) + muls_rev[i - 1] = at(i) * muls_rev[i]; + + scalar_t inv = (muls[size - 1] * at(size - 1)).reciprocal(); + + for (size_t i = 0; i < size; ++i) + at(i) = muls[i] * inv * muls_rev[i]; + } + + scalar_t calc_hs() const; + + }; // scalar_vec_t + + + // treats vector of scalars as an M x N matrix just for convenience + template + struct scalar_mat_t : public scalar_vec_t + { + typedef scalar_vec_t super_t; + static_assert(N > 0, "invalid N value"); + + scalar_mat_t() {} + scalar_mat_t(size_t n) : super_t(n) {} + scalar_mat_t(std::initializer_list init_list) : super_t(init_list) {} + + // matrix accessor M rows x N cols + scalar_t& operator()(size_t row, size_t col) + { + return at(row * N + col); + } + }; // scalar_mat_t + + + // // Global constants // @@ -908,4 +1048,11 @@ namespace crypto }; // hash_helper_t struct + inline scalar_t scalar_vec_t::calc_hs() const + { + // hs won't touch memory if size is 0, so it's safe + return hash_helper_t::hs(data(), sizeof(scalar_t) * size()); + } + + } // namespace crypto