From c82aacd437b8b9fd21f13d8e791f5a3e34ab6665 Mon Sep 17 00:00:00 2001 From: sowle Date: Fri, 25 Dec 2020 23:33:01 +0300 Subject: [PATCH] experimental crypto: sc_invert2, performance comparison, point_G arithmetic fixed --- tests/functional_tests/crypto_tests.cpp | 372 ++++++++++-------------- 1 file changed, 155 insertions(+), 217 deletions(-) diff --git a/tests/functional_tests/crypto_tests.cpp b/tests/functional_tests/crypto_tests.cpp index 9dbb18df..0b04da05 100644 --- a/tests/functional_tests/crypto_tests.cpp +++ b/tests/functional_tests/crypto_tests.cpp @@ -57,210 +57,80 @@ void sc_exp(unsigned char* out, const unsigned char* z, const unsigned char* s) } } -// out = z ^ -1 (= z ^ (L - 2) according to Fermat little theorem) -void sc_invert(unsigned char* out, const unsigned char* z) +/* + Input: + s[0]+256*s[1]+...+256^31*s[31] = s + a[0]+256*a[1]+...+256^31*a[31] = a + n + * + Output: + s[0]+256*s[1]+...+256^31*s[31] = a * s^(2^n) mod l + where l = 2^252 + 27742317777372353535851937790883648493. + Overwrites s in place. 
+ */ + +static inline void +sc_sqmul(unsigned char s[32], const int n, const unsigned char a[32]) { - memcpy(out, z, sizeof(crypto::ec_scalar)); - for (size_t i = 0; i < 128; ++i) - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - 
sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, out); - 
sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, out); - sc_mul(out, out, z); - sc_mul(out, out, out); - sc_mul(out, out, z); + int i; + for (i = 0; i < n; ++i) + sc_mul(s, s, s); + sc_mul(s, s, a); } +void sc_invert2(unsigned char* recip, const unsigned char* s) +{ + unsigned char _10[32], _100[32], _1000[32], _10000[32], _100000[32], + _1000000[32], _10010011[32], _10010111[32], _100110[32], _1010[32], + _1010000[32], _1010011[32], _1011[32], _10110[32], _10111101[32], + _11[32], _1100011[32], _1100111[32], _11010011[32], _1101011[32], + _11100111[32], _11101011[32], _11110101[32]; + + sc_mul(_10, s, s); + sc_mul(_11, s, _10); + sc_mul(_100, s, _11); + sc_mul(_1000, _100, _100); + sc_mul(_1010, _10, _1000); + sc_mul(_1011, s, _1010); + sc_mul(_10000, _1000, _1000); + sc_mul(_10110, _1011, _1011); + sc_mul(_100000, _1010, _10110); + sc_mul(_100110, _10000, _10110); + sc_mul(_1000000, _100000, _100000); + sc_mul(_1010000, _10000, _1000000); + sc_mul(_1010011, _11, _1010000); + sc_mul(_1100011, _10000, _1010011); + sc_mul(_1100111, _100, _1100011); + sc_mul(_1101011, _100, _1100111); + sc_mul(_10010011, _1000000, _1010011); + sc_mul(_10010111, 
_100, _10010011); + sc_mul(_10111101, _100110, _10010111); + sc_mul(_11010011, _10110, _10111101); + sc_mul(_11100111, _1010000, _10010111); + sc_mul(_11101011, _100, _11100111); + sc_mul(_11110101, _1010, _11101011); + + sc_mul(recip, _1011, _11110101); + + sc_sqmul(recip, 126, _1010011); + + sc_sqmul(recip, 9, _10); + sc_mul(recip, recip, _11110101); + sc_sqmul(recip, 7, _1100111); + sc_sqmul(recip, 9, _11110101); + sc_sqmul(recip, 11, _10111101); + sc_sqmul(recip, 8, _11100111); + sc_sqmul(recip, 9, _1101011); + sc_sqmul(recip, 6, _1011); + sc_sqmul(recip, 14, _10010011); + sc_sqmul(recip, 10, _1100011); + sc_sqmul(recip, 9, _10010111); + sc_sqmul(recip, 10, _11110101); + sc_sqmul(recip, 8, _11010011); + sc_sqmul(recip, 8, _11101011); +} + + // // Helpers @@ -327,16 +197,14 @@ struct scalar_t m_u64[3] = a3; } - scalar_t(int64_t v) + scalar_t(uint64_t v) { zero(); if (v == 0) { return; } - //unsigned char bytes[32] = {0}; - reinterpret_cast(m_s) = v; - //fe_frombytes(m_fe, bytes); + reinterpret_cast(m_s) = v; // do not need to call reduce as 2^64 < L } @@ -497,6 +365,7 @@ struct scalar_t }; // struct scalar_t + //__declspec(align(32)) struct point_t { @@ -518,14 +387,14 @@ struct point_t return ge_frombytes_vartime(&m_p3, reinterpret_cast(&pk)) == 0; } - operator crypto::public_key() + operator crypto::public_key() const { crypto::public_key result; ge_p3_tobytes((unsigned char*)&result, &m_p3); return result; } - point_t operator+(const point_t& rhs) + point_t operator+(const point_t& rhs) const { point_t result; ge_cached rhs_c; @@ -536,7 +405,7 @@ struct point_t return result; } - point_t operator-(const point_t& rhs) + point_t operator-(const point_t& rhs) const { point_t result; ge_cached rhs_c; @@ -554,6 +423,15 @@ struct point_t return result; } + friend point_t operator/(const point_t& lhs, const scalar_t& rhs) + { + point_t result; + scalar_t reciprocal; + sc_invert(&reciprocal.m_s[0], &rhs.m_s[0]); + ge_scalarmult_p3(&result.m_p3, &reciprocal.m_s[0], 
&lhs.m_p3); + return result; + } + friend bool operator==(const point_t& lhs, const point_t& rhs) { // convert to xy form, then compare components (because (z, y, z, t) representation is not unique) @@ -581,20 +459,28 @@ struct point_g_t : public point_t { point_g_t() { - + scalar_t one(1); + ge_scalarmult_base(&m_p3, &one.m_s[0]); } friend point_t operator*(const scalar_t& lhs, const point_g_t&) { point_t result; - ge_scalarmult_base(&result.m_p3, reinterpret_cast(&lhs)); + ge_scalarmult_base(&result.m_p3, &lhs.m_s[0]); return result; } - /*friend point_t operator*(const int64_t lhs, const point_g_t& rhs) + friend point_t operator/(const point_g_t&, const scalar_t& rhs) { - return operator*(scalar_t) - }*/ + point_t result; + scalar_t reciprocal; + sc_invert(&reciprocal.m_s[0], &rhs.m_s[0]); + ge_scalarmult_base(&result.m_p3, &reciprocal.m_s[0]); + return result; + } + + + static_assert(sizeof(crypto::public_key) == 32, "size error"); @@ -839,6 +725,49 @@ TEST(crypto, sc_mul_performance) return true; } +TEST(crypto, sc_invert_performance) /* checks sc_invert2 against sc_invert on 10000 scalars, then times both implementations over 10 runs */ +{ + std::vector<scalar_t> scalars(10000); + LOG_PRINT_L0("Running " << scalars.size() << " sc_invert tests..."); + + for (auto& s : scalars) + { + s.make_random(); + scalar_t a, b; + sc_invert(&a.m_s[0], &s.m_s[0]); + sc_invert2(&b.m_s[0], &s.m_s[0]); + ASSERT_EQ(a, b); /* both implementations must agree before their timings are compared */ + } + + std::vector<scalar_t> results_0(scalars.size()); + std::vector<scalar_t> results_1(scalars.size()); + std::vector<scalar_t> results_2(scalars.size()); + + for (size_t j = 0; j < 10; ++j) + { + LOG_PRINT_L0("Run #" << j); + + // warm-up + for (size_t i = 0; i < scalars.size(); ++i) + sc_invert(&results_0[i].m_s[0], &scalars[i].m_s[0]); + + TIME_MEASURE_START(t_1); + for (size_t i = 0; i < scalars.size(); ++i) + sc_invert(&results_1[i].m_s[0], &scalars[i].m_s[0]); + TIME_MEASURE_FINISH(t_1); + + TIME_MEASURE_START(t_2); + for (size_t i = 0; i < scalars.size(); ++i) + sc_invert2(&results_2[i].m_s[0], &scalars[i].m_s[0]); /* t_2 must time sc_invert2: it is the value reported under the "sc_invert2: " label */ + TIME_MEASURE_FINISH(t_2); + + LOG_PRINT_L0("sc_invert: " << std::fixed << 
std::setprecision(3) << 1.0 * t_1 / scalars.size() << " mcs " << (t_1 < t_2 ? "WIN" : "")); + LOG_PRINT_L0("sc_invert2: " << std::fixed << std::setprecision(3) << 1.0 * t_2 / scalars.size() << " mcs " << (t_1 < t_2 ? "" : " WIN")); + } + + return true; +} + TEST(crypto, scalar_arithmetic_assignment) { std::vector scalars(1000); @@ -871,10 +800,19 @@ TEST(crypto, point_basics) point_t K = 193847 * point_G; point_t C = E + K; - ASSERT_TRUE(X == 16 * point_G); - ASSERT_TRUE(C - K == E); - ASSERT_TRUE(C - E == K); - ASSERT_TRUE(C == 193851 * point_G); + ASSERT_EQ(X, 16 * point_G); + ASSERT_EQ(C - K, E); + ASSERT_EQ(C - E, K); + ASSERT_EQ(C, (193847 + 4) * point_G); + + ASSERT_EQ(point_G / 1, 1 * point_G); + ASSERT_EQ(C / 3, E / 3 + K / 3); + //ASSERT_EQ(K, 61 * (K / (61))); + //ASSERT_EQ(K, 192847 * (K / scalar_t(192847))); + ASSERT_EQ(K, 61 * (283 * (192847 * (K / (192847ull * 283 * 61))))); + + ASSERT_EQ(E, point_G + point_G + point_G + point_G); + ASSERT_EQ(E - point_G, 3 * point_G); return true; }