diff --git a/tests/functional_tests/crypto_tests.cpp b/tests/functional_tests/crypto_tests.cpp index ae79d65c..9dbb18df 100644 --- a/tests/functional_tests/crypto_tests.cpp +++ b/tests/functional_tests/crypto_tests.cpp @@ -390,24 +390,42 @@ struct scalar_t scalar_t operator+(const scalar_t& v) const { scalar_t result; - sc_add(reinterpret_cast(&result), reinterpret_cast(&m_s), reinterpret_cast(&v)); + sc_add(&result.m_s[0], &m_s[0], &v.m_s[0]); return result; } + scalar_t& operator+=(const scalar_t& v) + { + sc_add(&m_s[0], &m_s[0], &v.m_s[0]); + return *this; + } + scalar_t operator-(const scalar_t& v) const { scalar_t result; - sc_sub(reinterpret_cast(&result), reinterpret_cast(&m_s), reinterpret_cast(&v)); + sc_sub(&result.m_s[0], &m_s[0], &v.m_s[0]); return result; } + scalar_t& operator-=(const scalar_t& v) + { + sc_sub(&m_s[0], &m_s[0], &v.m_s[0]); + return *this; + } + scalar_t operator*(const scalar_t& v) const { scalar_t result; - sc_mul(reinterpret_cast(&result), reinterpret_cast(&m_s), reinterpret_cast(&v)); + sc_mul(&result.m_s[0], &m_s[0], &v.m_s[0]); return result; } + scalar_t& operator*=(const scalar_t& v) + { + sc_mul(&m_s[0], &m_s[0], &v.m_s[0]); + return *this; + } + scalar_t reciprocal() const { scalar_t result; @@ -420,6 +438,14 @@ struct scalar_t return operator*(v.reciprocal()); } + scalar_t& operator/=(const scalar_t& v) + { + scalar_t reciprocal; + sc_invert(&reciprocal.m_s[0], &v.m_s[0]); + sc_mul(&m_s[0], &m_s[0], &reciprocal.m_s[0]); + return *this; + } + bool operator==(const scalar_t& rhs) const { return @@ -429,6 +455,15 @@ struct scalar_t m_u64[3] == rhs.m_u64[3]; } + bool operator!=(const scalar_t& rhs) const + { + return + m_u64[0] != rhs.m_u64[0] || + m_u64[1] != rhs.m_u64[1] || + m_u64[2] != rhs.m_u64[2] || + m_u64[3] != rhs.m_u64[3]; + } + bool operator<(const scalar_t& rhs) const { if (m_u64[3] < rhs.m_u64[3]) return true; @@ -462,7 +497,7 @@ struct scalar_t }; // struct scalar_t - //__declspec(align(32)) +//__declspec(align(32)) struct point_t { // A point(x, y) is represented in extended homogeneous coordinates (X, Y, Z, T) @@ -579,7 +614,7 @@ static const scalar_t scalar_256m1 = { 0xffffffffffffffff, 0xffffffffffffffff, */ #define TEST(test_name_a, test_name_b) \ static bool test_name_a ## _ ## test_name_b(); \ - static test_keeper_t test_name_a ## _ ## test_name_b ## keeper(STR(COMBINE(test_name_a ## _, test_name_b)), & test_name_a ## _ ## test_name_b); \ + static test_keeper_t test_name_a ## _ ## test_name_b ## _ ## keeper(STR(COMBINE(test_name_a ## _, test_name_b)), & test_name_a ## _ ## test_name_b); \ static bool test_name_a ## _ ## test_name_b() #define ASSERT_TRUE(expr) CHECK_AND_ASSERT_MES(expr, false, "This is not true: " #expr) #define ASSERT_FALSE(expr) CHECK_AND_ASSERT_MES((expr) == false, false, "This is not false: " #expr) @@ -596,6 +631,8 @@ struct test_keeper_t }; + + // // Tests // @@ -763,9 +800,10 @@ TEST(crypto, scalar_basics) std::cout << "0 = " << zero << std::endl; std::cout << "1 = " << one << std::endl; std::cout << "L = " << scalar_L << std::endl; - std::cout << "Lm1 = " << scalar_Lm1 << std::endl; + std::cout << "L-1 = " << scalar_Lm1 << std::endl; std::cout << "P = " << scalar_P << std::endl; - std::cout << "Pm1 = " << scalar_Pm1 << std::endl; + std::cout << "P-1 = " << scalar_Pm1 << std::endl; + std::cout << std::endl; // check rolling over L for scalars arithmetics ASSERT_EQ(scalar_Lm1 + 1, 0); @@ -773,18 +811,54 @@ TEST(crypto, scalar_basics) ASSERT_EQ(scalar_Lm1 * 2, scalar_Lm1 - 1); // (L - 1) * 2 = L + L - 2 = (L - 1) - 1 (mod L) ASSERT_EQ(scalar_Lm1 * 100, scalar_Lm1 - 99); ASSERT_EQ(scalar_Lm1 * scalar_Lm1, 1); // (L - 1) * (L - 1) = L*L - 2L + 1 = 1 (mod L) + ASSERT_EQ(scalar_Lm1 * (scalar_Lm1 - 1) * scalar_Lm1, scalar_Lm1 - 1); + ASSERT_EQ(scalar_L * scalar_L, 0); - std::cout << std::endl; + ASSERT_EQ(scalar_t(3) / scalar_Lm1, scalar_t(3) * scalar_Lm1); // because (L - 1) ^ 2 = 1 + return true; +} +TEST(crypto, sc_mul_performance) +{ + std::vector scalars(100000); + for (auto& s : scalars) + s.make_random(); - scalar_t a = 2; - a = a / 2; - std::cout << "2 / 2 = " << a << std::endl; - a = scalar_Lm1 / 2; - std::cout << "L-1 / 2 = " << a << std::endl; - a = a * 2; - std::cout << "L-1 / 2 * 2 = " << a << std::endl; + scalar_t m = 1; + + TIME_MEASURE_START(t); + for (auto& s : scalars) + m *= s; + TIME_MEASURE_FINISH(t); + + std::cout << m << std::endl; + + LOG_PRINT_L0("sc_mul: " << std::fixed << std::setprecision(3) << t / 1000.0 << " ms"); + + return true; +} + +TEST(crypto, scalar_arithmetic_assignment) +{ + std::vector scalars(1000); + scalar_t mm = 1, sum = 0; + for (auto& s : scalars) + { + s.make_random(); + mm /= s; + sum += s; + } + ASSERT_TRUE(!mm.is_zero() && mm != scalar_t(1)); + ASSERT_TRUE(!sum.is_zero()); + std::shuffle(scalars.begin(), scalars.end(), crypto::uniform_random_bit_generator()); + for (auto& s : scalars) + { + mm *= s; + sum -= s; + } + ASSERT_EQ(mm, 1); + ASSERT_EQ(sum, 0); return true; } @@ -854,10 +928,14 @@ int crypto_tests() for (size_t i = 0; i < g_tests.size(); ++i) { auto& test = g_tests[i]; + TIME_MEASURE_START(runtime); bool r = test.second(); + TIME_MEASURE_FINISH(runtime); + uint64_t runtime_ms = runtime / 1000; + uint64_t runtime_mcs = runtime % 1000; if (r) { - LOG_PRINT_GREEN(" " << std::setw(40) << std::left << test.first << "OK", LOG_LEVEL_0); + LOG_PRINT_GREEN(" " << std::setw(40) << std::left << test.first << "OK [" << runtime_ms << "." << std::setw(3) << std::setfill('0') << runtime_mcs << " ms]", LOG_LEVEL_0); } else { @@ -873,7 +951,7 @@ int crypto_tests() return 0; } - LOG_PRINT_RED_L0(ENDL, LOG_LEVEL_0); + LOG_PRINT_RED_L0(ENDL); LOG_PRINT_RED_L0(ENDL << "Failed tests:"); for (size_t i : failed_tests) {