Source code

Revision control

Copy as Markdown

Other Tools

// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
#include "lib/jxl/enc_icc_codec.h"
#include <jxl/memory_manager.h>
#include <cstdint>
#include <limits>
#include <map>
#include <vector>
#include "lib/jxl/base/status.h"
#include "lib/jxl/enc_ans.h"
#include "lib/jxl/enc_aux_out.h"
#include "lib/jxl/fields.h"
#include "lib/jxl/icc_codec_common.h"
#include "lib/jxl/padded_bytes.h"
namespace jxl {
namespace {
// Unshuffles or de-interleaves bytes, for example with width 2, turns
// "AaBbCcDc" into "ABCDabcd", this for example de-interleaves UTF-16 bytes into
// first all the high order bytes, then all the low order bytes.
// Transposes a matrix of width columns and ceil(size / width) rows. There are
// size elements, size may be < width * height, if so the
// last elements of the bottom row are missing, the missing spots are
// transposed along with the filled spots, and the result has the missing
// elements at the bottom of the rightmost column. The input is the input matrix
// in scanline order, the output is the result matrix in scanline order, with
// missing elements skipped over (this may occur at multiple positions).
Status Unshuffle(JxlMemoryManager* memory_manager, uint8_t* data, size_t size,
size_t width) {
size_t height = (size + width - 1) / width; // amount of rows of input
PaddedBytes result(memory_manager);
JXL_ASSIGN_OR_RETURN(result,
PaddedBytes::WithInitialSpace(memory_manager, size));
// i = input index, j output index
size_t s = 0;
size_t j = 0;
for (size_t i = 0; i < size; i++) {
result[j] = data[i];
j += height;
if (j >= size) j = ++s;
}
for (size_t i = 0; i < size; i++) {
data[i] = result[i];
}
return true;
}
// This is performed by the encoder, the encoder must be able to encode any
// random byte stream (not just byte streams that are a valid ICC profile), so
// an error returned by this function is an implementation error.
Status PredictAndShuffle(size_t stride, size_t width, int order, size_t num,
const uint8_t* data, size_t size, size_t* pos,
PaddedBytes* result) {
JXL_RETURN_IF_ERROR(CheckOutOfBounds(*pos, num, size));
JxlMemoryManager* memory_manager = result->memory_manager();
// Required by the specification, see decoder. stride * 4 must be < *pos.
if (!*pos || ((*pos - 1u) >> 2u) < stride) {
return JXL_FAILURE("Invalid stride");
}
if (*pos < stride * 4) return JXL_FAILURE("Too large stride");
size_t start = result->size();
for (size_t i = 0; i < num; i++) {
uint8_t predicted =
LinearPredictICCValue(data, *pos, i, stride, width, order);
JXL_RETURN_IF_ERROR(result->push_back(data[*pos + i] - predicted));
}
*pos += num;
if (width > 1) {
JXL_RETURN_IF_ERROR(
Unshuffle(memory_manager, result->data() + start, num, width));
}
return true;
}
inline Status EncodeVarInt(uint64_t value, PaddedBytes* data) {
size_t pos = data->size();
JXL_RETURN_IF_ERROR(data->resize(data->size() + 9));
size_t output_size = data->size();
uint8_t* output = data->data();
// While more than 7 bits of data are left,
// store 7 bits and set the next byte flag
while (value > 127) {
// TODO(eustas): should it be `<` ?
JXL_ENSURE(pos <= output_size);
// |128: Set the next byte flag
output[pos++] = (static_cast<uint8_t>(value & 127)) | 128;
// Remove the seven bits we just wrote
value >>= 7;
}
// TODO(eustas): should it be `<` ?
JXL_ENSURE(pos <= output_size);
output[pos++] = static_cast<uint8_t>(value & 127);
return data->resize(pos);
}
constexpr size_t kSizeLimit = std::numeric_limits<uint32_t>::max() >> 2;
} // namespace
// Outputs a transformed form of the given icc profile. The result itself is
// not particularly smaller than the input data in bytes, but it will be in a
// form that is easier to compress (more zeroes, ...) and will compress better
// with brotli.
Status PredictICC(const uint8_t* icc, size_t size, PaddedBytes* result) {
JxlMemoryManager* memory_manager = result->memory_manager();
PaddedBytes commands{memory_manager};
PaddedBytes data{memory_manager};
static_assert(sizeof(size_t) >= 4, "size_t is too short");
// Fuzzer expects that PredictICC can accept any input,
// but 1GB should be enough for any purpose.
if (size > kSizeLimit) {
return JXL_FAILURE("ICC profile is too large");
}
JXL_RETURN_IF_ERROR(EncodeVarInt(size, result));
// Header
PaddedBytes header{memory_manager};
JXL_RETURN_IF_ERROR(header.append(ICCInitialHeaderPrediction(size)));
for (size_t i = 0; i < kICCHeaderSize && i < size; i++) {
ICCPredictHeader(icc, size, header.data(), i);
JXL_RETURN_IF_ERROR(data.push_back(icc[i] - header[i]));
}
if (size <= kICCHeaderSize) {
JXL_RETURN_IF_ERROR(EncodeVarInt(0, result)); // 0 commands
for (uint8_t b : data) {
JXL_RETURN_IF_ERROR(result->push_back(b));
}
return true;
}
std::vector<Tag> tags;
std::vector<size_t> tagstarts;
std::vector<size_t> tagsizes;
std::map<size_t, size_t> tagmap;
// Tag list
size_t pos = kICCHeaderSize;
if (pos + 4 <= size) {
uint64_t numtags = DecodeUint32(icc, size, pos);
pos += 4;
JXL_RETURN_IF_ERROR(EncodeVarInt(numtags + 1, &commands));
uint64_t prevtagstart = kICCHeaderSize + numtags * 12;
uint32_t prevtagsize = 0;
for (size_t i = 0; i < numtags; i++) {
if (pos + 12 > size) break;
Tag tag = DecodeKeyword(icc, size, pos + 0);
uint32_t tagstart = DecodeUint32(icc, size, pos + 4);
uint32_t tagsize = DecodeUint32(icc, size, pos + 8);
pos += 12;
tags.push_back(tag);
tagstarts.push_back(tagstart);
tagsizes.push_back(tagsize);
tagmap[tagstart] = tags.size() - 1;
uint8_t tagcode = kCommandTagUnknown;
for (size_t j = 0; j < kNumTagStrings; j++) {
if (tag == *kTagStrings[j]) {
tagcode = j + kCommandTagStringFirst;
break;
}
}
if (tag == kRtrcTag && pos + 24 < size) {
bool ok = true;
ok &= DecodeKeyword(icc, size, pos + 0) == kGtrcTag;
ok &= DecodeKeyword(icc, size, pos + 12) == kBtrcTag;
if (ok) {
for (size_t kk = 0; kk < 8; kk++) {
if (icc[pos - 8 + kk] != icc[pos + 4 + kk]) ok = false;
if (icc[pos - 8 + kk] != icc[pos + 16 + kk]) ok = false;
}
}
if (ok) {
tagcode = kCommandTagTRC;
pos += 24;
i += 2;
}
}
if (tag == kRxyzTag && pos + 24 < size) {
bool ok = true;
ok &= DecodeKeyword(icc, size, pos + 0) == kGxyzTag;
ok &= DecodeKeyword(icc, size, pos + 12) == kBxyzTag;
uint32_t offsetr = tagstart;
uint32_t offsetg = DecodeUint32(icc, size, pos + 4);
uint32_t offsetb = DecodeUint32(icc, size, pos + 16);
uint32_t sizer = tagsize;
uint32_t sizeg = DecodeUint32(icc, size, pos + 8);
uint32_t sizeb = DecodeUint32(icc, size, pos + 20);
ok &= sizer == 20;
ok &= sizeg == 20;
ok &= sizeb == 20;
ok &= (offsetg == offsetr + 20);
ok &= (offsetb == offsetr + 40);
if (ok) {
tagcode = kCommandTagXYZ;
pos += 24;
i += 2;
}
}
uint8_t command = tagcode;
uint64_t predicted_tagstart = prevtagstart + prevtagsize;
if (predicted_tagstart != tagstart) command |= kFlagBitOffset;
size_t predicted_tagsize = prevtagsize;
if (tag == kRxyzTag || tag == kGxyzTag || tag == kBxyzTag ||
tag == kKxyzTag || tag == kWtptTag || tag == kBkptTag ||
tag == kLumiTag) {
predicted_tagsize = 20;
}
if (predicted_tagsize != tagsize) command |= kFlagBitSize;
JXL_RETURN_IF_ERROR(commands.push_back(command));
if (tagcode == 1) {
JXL_RETURN_IF_ERROR(AppendKeyword(tag, &data));
}
if (command & kFlagBitOffset)
JXL_RETURN_IF_ERROR(EncodeVarInt(tagstart, &commands));
if (command & kFlagBitSize)
JXL_RETURN_IF_ERROR(EncodeVarInt(tagsize, &commands));
prevtagstart = tagstart;
prevtagsize = tagsize;
}
}
// Indicate end of tag list or varint indicating there's none
JXL_RETURN_IF_ERROR(commands.push_back(0));
// Main content
// The main content in a valid ICC profile contains tagged elements, with the
// tag types (4 letter names) given by the tag list above, and the tag list
// pointing to the start and indicating the size of each tagged element. It is
// allowed for tagged elements to overlap, e.g. the curve for R, G and B could
// all point to the same one.
Tag tag;
size_t tagstart = 0;
size_t tagsize = 0;
size_t clutstart = 0;
// Should always check tag_sane before doing math with tagsize.
const auto tag_sane = [&tagsize]() {
return (tagsize > 8) && (tagsize < kSizeLimit);
};
size_t last0 = pos;
// This loop appends commands to the output, processing some sub-section of a
// current tagged element each time. We need to keep track of the tagtype of
// the current element, and update it when we encounter the boundary of a
// next one.
// It is not required that the input data is a valid ICC profile, if the
// encoder does not recognize the data it will still be able to output bytes
// but will not predict as well.
while (pos <= size) {
size_t last1 = pos;
PaddedBytes commands_add{memory_manager};
PaddedBytes data_add{memory_manager};
// This means the loop brought the position beyond the tag end.
// If tagsize is nonsensical, any pos looks "ok-ish".
if ((pos > tagstart + tagsize) && (tagsize < kSizeLimit)) {
tag = {{0, 0, 0, 0}}; // nonsensical value
}
if (commands_add.empty() && data_add.empty() && tagmap.count(pos) &&
pos + 4 <= size) {
size_t index = tagmap[pos];
tag = DecodeKeyword(icc, size, pos);
tagstart = tagstarts[index];
tagsize = tagsizes[index];
if (tag == kMlucTag && tag_sane() && pos + tagsize <= size &&
icc[pos + 4] == 0 && icc[pos + 5] == 0 && icc[pos + 6] == 0 &&
icc[pos + 7] == 0) {
size_t num = tagsize - 8;
JXL_RETURN_IF_ERROR(commands_add.push_back(kCommandTypeStartFirst + 3));
pos += 8;
JXL_RETURN_IF_ERROR(commands_add.push_back(kCommandShuffle2));
JXL_RETURN_IF_ERROR(EncodeVarInt(num, &commands_add));
size_t start = data_add.size();
for (size_t i = 0; i < num; i++) {
JXL_RETURN_IF_ERROR(data_add.push_back(icc[pos]));
pos++;
}
JXL_RETURN_IF_ERROR(
Unshuffle(memory_manager, data_add.data() + start, num, 2));
}
if (tag == kCurvTag && tag_sane() && pos + tagsize <= size &&
icc[pos + 4] == 0 && icc[pos + 5] == 0 && icc[pos + 6] == 0 &&
icc[pos + 7] == 0) {
size_t num = tagsize - 8;
if (num > 16 && num < (1 << 28) && pos + num <= size && pos > 0) {
JXL_RETURN_IF_ERROR(
commands_add.push_back(kCommandTypeStartFirst + 5));
pos += 8;
JXL_RETURN_IF_ERROR(commands_add.push_back(kCommandPredict));
int order = 1;
int width = 2;
int stride = width;
JXL_RETURN_IF_ERROR(
commands_add.push_back((order << 2) | (width - 1)));
JXL_RETURN_IF_ERROR(EncodeVarInt(num, &commands_add));
JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc,
size, &pos, &data_add));
}
}
}
if (tag == kMab_Tag || tag == kMba_Tag) {
Tag subTag = DecodeKeyword(icc, size, pos);
if (pos + 12 < size && (subTag == kCurvTag || subTag == kVcgtTag) &&
DecodeUint32(icc, size, pos + 4) == 0) {
uint32_t num = DecodeUint32(icc, size, pos + 8) * 2;
if (num > 16 && num < (1 << 28) && pos + 12 + num <= size) {
pos += 12;
last1 = pos;
JXL_RETURN_IF_ERROR(commands_add.push_back(kCommandPredict));
int order = 1;
int width = 2;
int stride = width;
JXL_RETURN_IF_ERROR(
commands_add.push_back((order << 2) | (width - 1)));
JXL_RETURN_IF_ERROR(EncodeVarInt(num, &commands_add));
JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc,
size, &pos, &data_add));
}
}
if (pos == tagstart + 24 && pos + 4 < size) {
// Note that this value can be remembered for next iterations of the
// loop, so the "pos == clutstart" if below can trigger during a later
// iteration.
clutstart = tagstart + DecodeUint32(icc, size, pos);
}
if (pos == clutstart && clutstart + 16 < size) {
size_t numi = icc[tagstart + 8];
size_t numo = icc[tagstart + 9];
size_t width = icc[clutstart + 16];
size_t stride = width * numo;
size_t num = width * numo;
for (size_t i = 0; i < numi && clutstart + i < size; i++) {
num *= icc[clutstart + i];
}
if ((width == 1 || width == 2) && num > 64 && num < (1 << 28) &&
pos + num <= size && pos > stride * 4) {
JXL_RETURN_IF_ERROR(commands_add.push_back(kCommandPredict));
int order = 1;
uint8_t flags =
(order << 2) | (width - 1) | (stride == width ? 0 : 16);
JXL_RETURN_IF_ERROR(commands_add.push_back(flags));
if (flags & 16) {
JXL_RETURN_IF_ERROR(EncodeVarInt(stride, &commands_add));
}
JXL_RETURN_IF_ERROR(EncodeVarInt(num, &commands_add));
JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc,
size, &pos, &data_add));
}
}
}
if (commands_add.empty() && data_add.empty() && tag == kGbd_Tag &&
tag_sane() && pos == tagstart + 8 && pos + tagsize - 8 <= size &&
pos > 16) {
size_t width = 4;
size_t order = 0;
size_t stride = width;
size_t num = tagsize - 8;
uint8_t flags = (order << 2) | (width - 1) | (stride == width ? 0 : 16);
JXL_RETURN_IF_ERROR(commands_add.push_back(kCommandPredict));
JXL_RETURN_IF_ERROR(commands_add.push_back(flags));
if (flags & 16) {
JXL_RETURN_IF_ERROR(EncodeVarInt(stride, &commands_add));
}
JXL_RETURN_IF_ERROR(EncodeVarInt(num, &commands_add));
JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc,
size, &pos, &data_add));
}
if (commands_add.empty() && data_add.empty() && pos + 20 <= size) {
Tag subTag = DecodeKeyword(icc, size, pos);
if (subTag == kXyz_Tag && DecodeUint32(icc, size, pos + 4) == 0) {
JXL_RETURN_IF_ERROR(commands_add.push_back(kCommandXYZ));
pos += 8;
for (size_t j = 0; j < 12; j++) {
JXL_RETURN_IF_ERROR(data_add.push_back(icc[pos++]));
}
}
}
if (commands_add.empty() && data_add.empty() && pos + 8 <= size) {
if (DecodeUint32(icc, size, pos + 4) == 0) {
Tag subTag = DecodeKeyword(icc, size, pos);
for (size_t i = 0; i < kNumTypeStrings; i++) {
if (subTag == *kTypeStrings[i]) {
JXL_RETURN_IF_ERROR(
commands_add.push_back(kCommandTypeStartFirst + i));
pos += 8;
break;
}
}
}
}
if (!(commands_add.empty() && data_add.empty()) || pos == size) {
if (last0 < last1) {
JXL_RETURN_IF_ERROR(commands.push_back(kCommandInsert));
JXL_RETURN_IF_ERROR(EncodeVarInt(last1 - last0, &commands));
while (last0 < last1) {
JXL_RETURN_IF_ERROR(data.push_back(icc[last0++]));
}
}
for (uint8_t b : commands_add) {
JXL_RETURN_IF_ERROR(commands.push_back(b));
}
for (uint8_t b : data_add) {
JXL_RETURN_IF_ERROR(data.push_back(b));
}
last0 = pos;
}
if (commands_add.empty() && data_add.empty()) {
pos++;
}
}
JXL_RETURN_IF_ERROR(EncodeVarInt(commands.size(), result));
for (uint8_t b : commands) {
JXL_RETURN_IF_ERROR(result->push_back(b));
}
for (uint8_t b : data) {
JXL_RETURN_IF_ERROR(result->push_back(b));
}
return true;
}
Status WriteICC(const Span<const uint8_t> icc, BitWriter* JXL_RESTRICT writer,
LayerType layer, AuxOut* JXL_RESTRICT aux_out) {
if (icc.empty()) return JXL_FAILURE("ICC must be non-empty");
JxlMemoryManager* memory_manager = writer->memory_manager();
PaddedBytes enc{memory_manager};
JXL_RETURN_IF_ERROR(PredictICC(icc.data(), icc.size(), &enc));
std::vector<std::vector<Token>> tokens(1);
JXL_RETURN_IF_ERROR(writer->WithMaxBits(128, layer, aux_out, [&] {
return U64Coder::Write(enc.size(), writer);
}));
for (size_t i = 0; i < enc.size(); i++) {
tokens[0].emplace_back(
ICCANSContext(i, i > 0 ? enc[i - 1] : 0, i > 1 ? enc[i - 2] : 0),
enc[i]);
}
HistogramParams params;
params.lz77_method = enc.size() < 4096 ? HistogramParams::LZ77Method::kOptimal
: HistogramParams::LZ77Method::kLZ77;
EntropyEncodingData code;
std::vector<uint8_t> context_map;
params.force_huffman = true;
JXL_ASSIGN_OR_RETURN(
size_t cost,
BuildAndEncodeHistograms(memory_manager, params, kNumICCContexts, tokens,
&code, &context_map, writer, layer, aux_out));
(void)cost;
JXL_RETURN_IF_ERROR(
WriteTokens(tokens[0], code, context_map, 0, writer, layer, aux_out));
return true;
}
} // namespace jxl