diff --git a/Makefile b/Makefile index 9f09b3d..646ca61 100644 --- a/Makefile +++ b/Makefile @@ -9,13 +9,13 @@ build: build-huffman build-arithmetic build-range build-rle build-huffman: - g++ -std=c++17 -O2 -Wall -Wextra -Werror -Ialgorithms/shared/cpp/include algorithms/shared/cpp/src/buffer_api.cpp algorithms/shared/cpp/src/cli_launcher.cpp algorithms/huffman/cpp/main.cpp -o algorithms/huffman/cpp/huffman_cpp + g++ -std=c++17 -O2 -Wall -Wextra -Werror -Ialgorithms/shared/cpp/include algorithms/shared/cpp/src/buffer_api.cpp algorithms/shared/cpp/src/cli_launcher.cpp algorithms/shared/cpp/src/frequency_table.cpp algorithms/huffman/cpp/main.cpp -o algorithms/huffman/cpp/huffman_cpp go build -o algorithms/huffman/go/huffman_go ./algorithms/huffman/go/cmd cargo build --manifest-path algorithms/huffman/rust/Cargo.toml --bin huffman_rust --release cp algorithms/huffman/rust/target/release/huffman_rust algorithms/huffman/rust/huffman_rust build-arithmetic: - g++ -std=c++17 -O2 -Wall -Wextra -Werror -Ialgorithms/shared/cpp/include algorithms/shared/cpp/src/buffer_api.cpp algorithms/shared/cpp/src/cli_launcher.cpp algorithms/arithmetic/cpp/main.cpp -o algorithms/arithmetic/cpp/arithmetic_cpp + g++ -std=c++17 -O2 -Wall -Wextra -Werror -Ialgorithms/shared/cpp/include algorithms/shared/cpp/src/buffer_api.cpp algorithms/shared/cpp/src/cli_launcher.cpp algorithms/shared/cpp/src/frequency_table.cpp algorithms/arithmetic/cpp/main.cpp -o algorithms/arithmetic/cpp/arithmetic_cpp go build -o algorithms/arithmetic/go/arithmetic_go ./algorithms/arithmetic/go/cmd cargo build --manifest-path algorithms/arithmetic/rust/Cargo.toml --bin arithmetic_rust --release cp algorithms/arithmetic/rust/target/release/arithmetic_rust algorithms/arithmetic/rust/arithmetic_rust @@ -40,7 +40,7 @@ test: test-data \ test-conformance test-cli-smoke test-shared-cpp: - g++ -std=c++17 -O2 -Wall -Wextra -Werror -DCOMPRESSKIT_NO_MAIN -Ialgorithms/shared/cpp/include algorithms/shared/cpp/src/buffer_api.cpp algorithms/shared/cpp/src/cli_launcher.cpp algorithms/huffman/cpp/main.cpp algorithms/arithmetic/cpp/main.cpp algorithms/range/cpp/main.cpp algorithms/rle/cpp/main.cpp algorithms/shared/cpp/tests/test_lifecycle.cpp -o algorithms/shared/cpp/tests/test_lifecycle + g++ -std=c++17 -O2 -Wall -Wextra -Werror -DCOMPRESSKIT_NO_MAIN -Ialgorithms/shared/cpp/include algorithms/shared/cpp/src/buffer_api.cpp algorithms/shared/cpp/src/cli_launcher.cpp algorithms/shared/cpp/src/frequency_table.cpp algorithms/huffman/cpp/main.cpp algorithms/arithmetic/cpp/main.cpp algorithms/range/cpp/main.cpp algorithms/rle/cpp/main.cpp algorithms/shared/cpp/tests/test_lifecycle.cpp -o algorithms/shared/cpp/tests/test_lifecycle ./algorithms/shared/cpp/tests/test_lifecycle test-shared-go: diff --git a/algorithms/arithmetic/cpp/main.cpp b/algorithms/arithmetic/cpp/main.cpp index 4782d33..8001a42 100644 --- a/algorithms/arithmetic/cpp/main.cpp +++ b/algorithms/arithmetic/cpp/main.cpp @@ -5,6 +5,7 @@ #include #include "compresskit/buffer_api.hpp" +#include "compresskit/frequency_table.hpp" class BitWriter { public: @@ -242,10 +243,17 @@ static std::vector build_frequencies_from_file(const std::string& inpu if (!in) { return freq; } - char c; - while (in.get(c)) { - unsigned char uc = static_cast(c); - freq[static_cast(uc)]++; + uint32_t overflow_symbol = 0; + const auto status = compresskit::accumulate_frequencies(in, freq, &overflow_symbol); + if (status == compresskit::FrequencyCountStatus::IO_ERROR) { + std::cerr << "Failed to read input file\n"; + freq.clear(); + return freq; + } + if (status == compresskit::FrequencyCountStatus::OVERFLOW) { + std::cerr << "Frequency overflow for symbol " << overflow_symbol << "\n"; + freq.clear(); + return freq; } freq[EOF_SYMBOL] = 1; scale_frequencies(freq); @@ -266,30 +274,20 @@ static std::vector build_cumulative(const std::vector& freq) } static void write_frequencies(std::ostream& out, const std::vector& freq) { - uint32_t count = static_cast(freq.size()); - out.write(reinterpret_cast(&count), sizeof(count)); - for (uint32_t v : freq) { - out.write(reinterpret_cast(&v), sizeof(v)); - } + compresskit::write_frequency_table(out, freq); } static bool read_frequencies(std::istream& in, std::vector& freq) { uint32_t count = 0; - in.read(reinterpret_cast(&count), sizeof(count)); - if (!in) { + const auto status = compresskit::read_frequency_table(in, freq, SYMBOL_LIMIT, &count); + if (status == compresskit::FrequencyTableReadStatus::TRUNCATED) { std::cerr << "Failed to read frequency table\n"; return false; } - if (count != SYMBOL_LIMIT) { + if (status == compresskit::FrequencyTableReadStatus::BAD_COUNT) { std::cerr << "Bad frequency table size: " << count << "\n"; return false; } - freq.assign(count, 0); - in.read(reinterpret_cast(freq.data()), freq.size() * sizeof(uint32_t)); - if (!in) { - std::cerr << "Failed to read frequency table\n"; - return false; - } return true; } @@ -305,8 +303,10 @@ static bool compress_file(const std::string& input_path, const std::string& outp } } } - std::vector freq = build_frequencies_from_file(input_path); + if (freq.empty()) { + return false; + } std::vector cumulative = build_cumulative(freq); std::ifstream in(input_path, std::ios::binary); diff --git a/algorithms/huffman/cpp/main.cpp b/algorithms/huffman/cpp/main.cpp index c3d4afd..9572d46 100644 --- a/algorithms/huffman/cpp/main.cpp +++ b/algorithms/huffman/cpp/main.cpp @@ -7,6 +7,7 @@ #include #include "compresskit/buffer_api.hpp" +#include "compresskit/frequency_table.hpp" class BitWriter { public: @@ -179,40 +180,37 @@ static std::vector build_frequencies_from_file(const std::string& inpu if (!in) { return freq; } - char c; - while (in.get(c)) { - unsigned char uc = static_cast(c); - freq[static_cast(uc)]++; + uint32_t overflow_symbol = 0; + const auto status = compresskit::accumulate_frequencies(in, freq, &overflow_symbol); + if (status == compresskit::FrequencyCountStatus::IO_ERROR) { + std::cerr << "Failed to read input file\n"; + freq.clear(); + return freq; + } + if (status == compresskit::FrequencyCountStatus::OVERFLOW) { + std::cerr << "Frequency overflow for symbol " << overflow_symbol << "\n"; + freq.clear(); + return freq; } freq[EOF_SYMBOL] = 1; return freq; } static void write_frequencies(std::ostream& out, const std::vector& freq) { - uint32_t count = static_cast(freq.size()); - out.write(reinterpret_cast(&count), sizeof(count)); - for (uint32_t v : freq) { - out.write(reinterpret_cast(&v), sizeof(v)); - } + compresskit::write_frequency_table(out, freq); } static bool read_frequencies(std::istream& in, std::vector& freq) { uint32_t count = 0; - in.read(reinterpret_cast(&count), sizeof(count)); - if (!in) { + const auto status = compresskit::read_frequency_table(in, freq, SYMBOL_LIMIT, &count); + if (status == compresskit::FrequencyTableReadStatus::TRUNCATED) { std::cerr << "Failed to read frequency table\n"; return false; } - if (count != SYMBOL_LIMIT) { + if (status == compresskit::FrequencyTableReadStatus::BAD_COUNT) { std::cerr << "Bad frequency table size: " << count << "\n"; return false; } - freq.assign(count, 0); - in.read(reinterpret_cast(freq.data()), freq.size() * sizeof(uint32_t)); - if (!in) { - std::cerr << "Failed to read frequency table\n"; - return false; - } return true; } @@ -230,6 +228,9 @@ static bool compress_file(const std::string& input_path, const std::string& outp } std::vector freq = build_frequencies_from_file(input_path); + if (freq.empty()) { + return false; + } UniqueNode root(build_tree(freq)); // RAII: automatic cleanup std::vector codes(SYMBOL_LIMIT); std::string prefix; diff --git a/algorithms/shared/cpp/include/compresskit/frequency_table.hpp b/algorithms/shared/cpp/include/compresskit/frequency_table.hpp new file mode 100644 index 0000000..4faa14c --- /dev/null +++ b/algorithms/shared/cpp/include/compresskit/frequency_table.hpp @@ -0,0 +1,31 @@ +#pragma once + +#include +#include +#include +#include + +namespace compresskit { + +enum class FrequencyTableReadStatus { + OK = 0, + TRUNCATED, + BAD_COUNT, +}; + +enum class FrequencyCountStatus { + OK = 0, + IO_ERROR, + OVERFLOW, +}; + +bool write_frequency_table(std::ostream& out, const std::vector& freq); + +FrequencyTableReadStatus read_frequency_table(std::istream& in, std::vector& freq, + uint32_t expected_count, + uint32_t* actual_count = nullptr); + +FrequencyCountStatus accumulate_frequencies(std::istream& in, std::vector& freq, + uint32_t* overflow_symbol = nullptr); + +} // namespace compresskit diff --git a/algorithms/shared/cpp/src/frequency_table.cpp b/algorithms/shared/cpp/src/frequency_table.cpp new file mode 100644 index 0000000..41fa4e2 --- /dev/null +++ b/algorithms/shared/cpp/src/frequency_table.cpp @@ -0,0 +1,95 @@ +#include "compresskit/frequency_table.hpp" + +#include +#include + +namespace compresskit { +namespace { + +bool write_u32_le(std::ostream& out, uint32_t value) { + const std::array bytes = { + static_cast(value & 0xFFu), + static_cast((value >> 8) & 0xFFu), + static_cast((value >> 16) & 0xFFu), + static_cast((value >> 24) & 0xFFu), + }; + out.write(bytes.data(), static_cast(bytes.size())); + return static_cast(out); +} + +bool read_u32_le(std::istream& in, uint32_t& value) { + std::array bytes{}; + in.read(reinterpret_cast(bytes.data()), static_cast(bytes.size())); + if (!in) { + return false; + } + value = static_cast(bytes[0]) | (static_cast(bytes[1]) << 8) | + (static_cast(bytes[2]) << 16) | (static_cast(bytes[3]) << 24); + return true; +} + +} // namespace + +bool write_frequency_table(std::ostream& out, const std::vector& freq) { + if (!write_u32_le(out, static_cast(freq.size()))) { + return false; + } + for (uint32_t value : freq) { + if (!write_u32_le(out, value)) { + return false; + } + } + return true; +} + +FrequencyTableReadStatus read_frequency_table(std::istream& in, std::vector& freq, + uint32_t expected_count, uint32_t* actual_count) { + uint32_t count = 0; + if (!read_u32_le(in, count)) { + freq.clear(); + return FrequencyTableReadStatus::TRUNCATED; + } + if (actual_count) { + *actual_count = count; + } + if (expected_count != 0 && count != expected_count) { + freq.clear(); + return FrequencyTableReadStatus::BAD_COUNT; + } + + freq.assign(count, 0); + for (uint32_t& value : freq) { + if (!read_u32_le(in, value)) { + freq.clear(); + return FrequencyTableReadStatus::TRUNCATED; + } + } + return FrequencyTableReadStatus::OK; +} + +FrequencyCountStatus accumulate_frequencies(std::istream& in, std::vector& freq, + uint32_t* overflow_symbol) { + std::array buffer{}; + for (;;) { + in.read(reinterpret_cast(buffer.data()), static_cast(buffer.size())); + const std::streamsize read_count = in.gcount(); + for (std::streamsize i = 0; i < read_count; ++i) { + const uint32_t symbol = static_cast(buffer[static_cast(i)]); + if (freq[symbol] == std::numeric_limits::max()) { + if (overflow_symbol) { + *overflow_symbol = symbol; + } + return FrequencyCountStatus::OVERFLOW; + } + ++freq[symbol]; + } + if (in.eof()) { + return FrequencyCountStatus::OK; + } + if (!in) { + return FrequencyCountStatus::IO_ERROR; + } + } +} + +} // namespace compresskit diff --git a/algorithms/shared/cpp/tests/test_lifecycle.cpp b/algorithms/shared/cpp/tests/test_lifecycle.cpp index b775d62..700f812 100644 --- a/algorithms/shared/cpp/tests/test_lifecycle.cpp +++ b/algorithms/shared/cpp/tests/test_lifecycle.cpp @@ -1,10 +1,12 @@ #include #include #include +#include #include #include #include "compresskit/algorithms.hpp" +#include "compresskit/frequency_table.hpp" namespace { @@ -129,6 +131,65 @@ void test_decode_buffer_preserves_finish_retry_prefix() { assert(std::string(decoded.value.begin(), decoded.value.end()) == "uvwxyz"); } +void test_write_frequency_table_uses_little_endian_layout() { + std::ostringstream out(std::ios::binary); + const std::vector freq = {0x78563412u, 0x01020304u}; + + const bool ok = compresskit::write_frequency_table(out, freq); + assert(ok); + + const std::string bytes = out.str(); + const std::string expected( + "\x02\x00\x00\x00" + "\x12\x34\x56\x78" + "\x04\x03\x02\x01", + 12); + assert(bytes == expected); +} + +void test_read_frequency_table_decodes_little_endian_values() { + const std::string bytes( + "\x02\x00\x00\x00" + "\x12\x34\x56\x78" + "\x04\x03\x02\x01", + 12); + std::istringstream in(bytes, std::ios::binary); + std::vector freq; + uint32_t actual_count = 0; + + const auto status = compresskit::read_frequency_table(in, freq, 2, &actual_count); + + assert(status == compresskit::FrequencyTableReadStatus::OK); + assert(actual_count == 2); + assert((freq == std::vector{0x78563412u, 0x01020304u})); +} + +void test_read_frequency_table_reports_bad_count() { + const std::string bytes("\x02\x00\x00\x00", 4); + std::istringstream in(bytes, std::ios::binary); + std::vector freq; + uint32_t actual_count = 0; + + const auto status = compresskit::read_frequency_table(in, freq, 3, &actual_count); + + assert(status == compresskit::FrequencyTableReadStatus::BAD_COUNT); + assert(actual_count == 2); + assert(freq.empty()); +} + +void test_accumulate_frequencies_reports_overflow() { + std::vector freq(257, 0); + freq[0] = UINT32_MAX; + std::istringstream in(std::string(1, '\0'), std::ios::binary); + uint32_t overflow_symbol = UINT32_MAX; + + const auto status = compresskit::accumulate_frequencies(in, freq, &overflow_symbol); + + assert(status == compresskit::FrequencyCountStatus::OVERFLOW); + assert(overflow_symbol == 0); + assert(freq[0] == UINT32_MAX); +} + } // namespace int main() { @@ -145,6 +206,10 @@ int main() { test_encode_buffer_preserves_finish_retry_prefix(); test_decode_buffer_preserves_finish_retry_prefix(); + test_write_frequency_table_uses_little_endian_layout(); + test_read_frequency_table_decodes_little_endian_values(); + test_read_frequency_table_reports_bad_count(); + test_accumulate_frequencies_reports_overflow(); return 0; }