Source code

Revision control

Copy as Markdown

Other Tools

/*
* Copyright © 2023, VideoLAN and dav1d authors
* Copyright © 2023, Loongson Technology Corporation Limited
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "loongson_asm.S"
// Per-lane minimum-probability offsets added to v in decode_symbol_adapt.
// The macro loads 16 bytes (32 bytes for w == 16) starting at
// min_prob + 32 - 2 * (a2 + 1), so that lane a2 receives the final 0 entry.
// For small a2 that load extends up to 30 bytes beyond the 16 real entries.
// The second row of zeros keeps the whole load inside this object instead of
// reading whatever happens to follow it in the read-only data. The padded
// lanes all lie above lane a2, which always compares as a match, so the
// decoded symbol does not depend on their value.
const min_prob
.short 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4, 0
.short 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
endconst
// decode_symbol_adapt w
//
// Shared body of the msac_decode_symbol_adapt{4,8,16}_lsx entry points.
// Decodes one symbol from the range coder, optionally adapts the cdf, then
// renormalises and refills the bit window.
//
// w selects the vector width handling:
//   w == 16 : a second vector register covers lanes 8..15
//   w == 4  : only the low 8 bytes of the updated cdf are written back
//   w == 8  : one full 16-byte vector is written back
//
// Inputs, as used by the code below:
//   a0 = decoder context, accessed at byte offsets
//          0 buf_pos, 8 buf_end, 16 dif (64-bit), 24 rng, 28 cnt,
//          32 allow_update_cdf
//   a1 = cdf, array of 16-bit entries
//   a2 = index of the adaptation counter inside cdf (cdf[a2] is read and
//        written as "count"); presumably n_symbols - confirm against the
//        C prototype
// Output:
//   a0 = index of the first lane whose v is <= the top 16 bits of dif
// Stack: 48 bytes of scratch, released before the macro ends.
// Clobbers: a4-a7, t0-t8, vr0-vr10, and for w == 16 also
//           vr11, vr13, vr15, vr16, vr18; vr20 always, vr21 for w == 16.
.macro decode_symbol_adapt w
addi.d sp, sp, -48
addi.d a4, a0, 24
vldrepl.h vr0, a4, 0 //rng, replicated into every 16-bit lane
// f0 is the low part of vr0: this puts rng at sp+2, the halfword directly
// below the v array stored at sp+4, so it acts as "u" when lane 0 is chosen.
fst.s f0, sp, 0 //val==0
vld vr1, a1, 0 //cdf lanes 0..7
.if \w == 16
li.w t4, 16
vldx vr11, a1, t4 // cdf lanes 8..15
.endif
addi.d a6, a0, 16
vldrepl.d vr2, a6, 0 //dif
addi.d t0, a0, 32
ld.w t1, t0, 0 //allow_update_cdf
// t2 = min_prob + 32 - 2 * (a2 + 1), i.e. &min_prob[15 - a2], so that
// lane a2 is paired with the last table entry (0).
la.local t2, min_prob
addi.d t2, t2, 32
addi.w t3, a2, 1
slli.w t3, t3, 1
sub.d t2, t2, t3
vld vr3, t2, 0 //min_prob
.if \w == 16
vldx vr13, t2, t4
.endif
// vr4 = rng with its low byte cleared
vsrli.h vr4, vr0, 8 //r = s->rng >> 8
vslli.h vr4, vr4, 8 //r << 8
// vr5 = (cdf >> 6) << 7
vsrli.h vr5, vr1, 6
vslli.h vr5, vr5, 7
.if \w == 16
vsrli.h vr15, vr11, 6
vslli.h vr15, vr15, 7
.endif
// High half of the 16x16 unsigned product:
// ((r << 8) * ((cdf >> 6) << 7)) >> 16 == ((rng >> 8) * (cdf >> 6)) >> 1
vmuh.hu vr5, vr4, vr5
vadd.h vr5, vr5, vr3 //v
.if \w == 16
vmuh.hu vr15, vr4, vr15
vadd.h vr15, vr15, vr13
.endif
// Spill v to the stack so the scalar code can index it by the result.
addi.d t8, sp, 4
vst vr5, t8, 0 //store v
.if \w == 16
vstx vr15, t8, t4
.endif
// c = halfword 3 of dif, i.e. bits 48..63, broadcast to all lanes
vreplvei.h vr20, vr2, 3 //c
// Unsigned saturating v - c is 0 exactly when c >= v; turn that into an
// all-ones lane mask.
vssub.hu vr6, vr5, vr20 //c >=v
vseqi.h vr6, vr6, 0
.if \w == 16
vssub.hu vr16, vr15, vr20 //c >=v
vseqi.h vr16, vr16, 0
// Narrow both 16-bit masks to one byte per lane (16 byte lanes in vr21).
vpickev.b vr21, vr16, vr6
.endif
// Collect the lane sign bits into a scalar-sized bitmask in vr10.
.if \w <= 8
vmskltz.h vr10, vr6
.else
vmskltz.b vr10, vr21
.endif
// Skip adaptation when allow_update_cdf is 0.
beqz t1, .renorm\()\w
// update_cdf
alsl.d t1, a2, a1, 1 // t1 = &cdf[a2]
ld.h t2, t1, 0 //count
// rate = (count >> 4) + 4 + (a2 > 2)
srli.w t3, t2, 4 //count >> 4
addi.w t3, t3, 4
li.w t5, 2
sltu t5, t5, a2
add.w t3, t3, t5 //rate
sltui t5, t2, 32
add.w t2, t2, t5 //count + (count < 32)
vreplgr2vr.h vr9, t3
vseq.h vr7, vr7, vr7 // all ones
// Rounded average of the mask with 0xffff: 0xffff where the mask is set,
// 0x8000 where it is clear.
vavgr.hu vr5, vr6, vr7 //i >= val ? -1 : 32768
vsub.h vr5, vr5, vr1 // target - cdf
vsub.h vr8, vr1, vr6 // cdf - mask (cdf + 1 where the mask is set)
.if \w == 16
vavgr.hu vr15, vr16, vr7
vsub.h vr15, vr15, vr11
vsub.h vr18, vr11, vr16
.endif
// new cdf = (cdf - mask) + ((target - cdf) >> rate), arithmetic shift
vsra.h vr5, vr5, vr9
vadd.h vr8, vr8, vr5
.if \w == 4
fst.d f8, a1, 0 // write back only the low four entries
.else
vst vr8, a1, 0
.endif
.if \w == 16
vsra.h vr15, vr15, vr9
vadd.h vr18, vr18, vr15
vstx vr18, a1, t4
.endif
// Stored after the vector write-back, so the counter slot ends up holding
// the scalar count regardless of what the vector store put there.
st.h t2, t1, 0
.renorm\()\w:
vpickve2gr.h t3, vr10, 0
ctz.w a7, t3 // ret = index of the lowest set mask bit
// t3 = &v[ret]; u is the halfword just below it (rng itself when ret == 0)
alsl.d t3, a7, t8, 1
ld.hu t4, t3, 0 // v
addi.d t3, t3, -2
ld.hu t5, t3, 0 // u
sub.w t5, t5, t4 // rng = u - v
slli.d t4, t4, 48
vpickve2gr.d t6, vr2, 0
sub.d t6, t6, t4 // dif -= v << 48
// clz.w counts over 32 bits; for a value below 65536 the count is 16..31,
// so flipping bit 4 yields clz - 16, the shift that normalises rng.
clz.w t4, t5 // d
xori t4, t4, 16 // d
sll.d t6, t6, t4
addi.d a5, a0, 28 // cnt
ld.w t0, a5, 0
sll.w t5, t5, t4
sub.w t7, t0, t4 // cnt-d
st.w t5, a4, 0 // store rng
// No refill needed while cnt >= d (unsigned compare on the old cnt).
bgeu t0, t4, 9f
// refill
ld.d t0, a0, 0 // buf_pos
ld.d t1, a0, 8 // buf_end
addi.d t2, t0, 8
// Fewer than 8 bytes left before buf_end: take the careful path at 2.
bltu t1, t2, 2f
// Fast path: read 8 bytes at buf_pos, invert, byte-swap, and shift them
// into position below the bits still held in dif.
ld.d t3, t0, 0 // next_bits
addi.w t1, t7, -48 // shift_bits = cnt + 16 (- 64)
nor t3, t3, t3
sub.w t2, zero, t1 // 48 - cnt
revb.d t3, t3 // next_bits = bswap(next_bits)
srli.w t2, t2, 3 // num_bytes_read
srl.d t3, t3, t1 // next_bits >>= (shift_bits & 63)
b 3f
1:
// buf_pos >= buf_end: nothing left to read. t3 is negative, so after sign
// extension it is mostly ones; shifting it right by its own low 6 bits
// leaves ones in the low bits that get ORed into dif. buf_pos and cnt are
// not advanced on this path.
addi.w t3, t7, -48
srl.d t3, t3, t3 // pad with ones
b 4f
2:
bgeu t0, t1, 1b
// Between 1 and 7 bytes remain. Load the 8 bytes that end at buf_end (this
// starts before buf_pos) and shift out the bytes below buf_pos.
ld.d t3, t1, -8 // next_bits
sub.w t2, t2, t1 // 8 - num_bytes_left
sub.w t1, t1, t0 // num_bytes_left
slli.w t2, t2, 3
srl.d t3, t3, t2
addi.w t2, t7, -48
nor t3, t3, t3
sub.w t4, zero, t2
revb.d t3, t3
srli.w t4, t4, 3 // bytes the window could take
srl.d t3, t3, t2
// num_bytes_read = min(num_bytes_left, bytes the window could take)
sltu t2, t1, t4
maskeqz t1, t1, t2
masknez t2, t4, t2
or t2, t2, t1 // num_bytes_read
3:
slli.w t1, t2, 3
add.d t0, t0, t2
add.w t7, t7, t1 // cnt += num_bits_read
st.d t0, a0, 0 // store advanced buf_pos
4:
or t6, t6, t3 // dif |= next_bits
9:
st.w t7, a5, 0 // store cnt
st.d t6, a6, 0 // store dif
move a0, a7
addi.d sp, sp, 48
.endm
// msac_decode_symbol_adapt4_lsx: decode_symbol_adapt expanded with w = 4.
// One vector register of lanes is evaluated and, when adaptation is enabled,
// only the low 8 bytes (four 16-bit entries) of the cdf are written back.
// In: a0 = decoder context, a1 = cdf, a2 = index of the counter in cdf.
// Out: a0 = decoded symbol index.
function msac_decode_symbol_adapt4_lsx
decode_symbol_adapt 4
endfunc
// msac_decode_symbol_adapt8_lsx: decode_symbol_adapt expanded with w = 8.
// One vector register of eight 16-bit lanes is evaluated and, when adaptation
// is enabled, the full 16-byte vector is written back to the cdf.
// In: a0 = decoder context, a1 = cdf, a2 = index of the counter in cdf.
// Out: a0 = decoded symbol index.
function msac_decode_symbol_adapt8_lsx
decode_symbol_adapt 8
endfunc
// msac_decode_symbol_adapt16_lsx: decode_symbol_adapt expanded with w = 16.
// Two vector registers cover sixteen 16-bit lanes and, when adaptation is
// enabled, 32 bytes are written back to the cdf.
// In: a0 = decoder context, a1 = cdf, a2 = index of the counter in cdf.
// Out: a0 = decoded symbol index.
function msac_decode_symbol_adapt16_lsx
decode_symbol_adapt 16
endfunc
// msac_decode_bool_lsx
//
// Decodes one boolean from the range coder using a caller-supplied
// probability, then renormalises and refills the bit window.
//
// In:  a0 = decoder context, accessed at byte offsets
//             0 buf_pos, 8 buf_end, 16 dif (64-bit), 24 rng, 28 cnt
//      a1 = probability f (only f >> 6 is used)
// Out: a0 = 1 when dif < (v << 48), otherwise 0
// Clobbers: a1, a5, t0-t8.
function msac_decode_bool_lsx
ld.w t0, a0, 24 // rng
srli.w a1, a1, 6 // f >> 6
ld.d t1, a0, 16 // dif
srli.w t2, t0, 8 // r >> 8
mul.w t2, t2, a1
ld.w a5, a0, 28 // cnt
// v = (((r >> 8) * (f >> 6)) >> 1) + 4
srli.w t2, t2, 1
addi.w t2, t2, 4 // v
slli.d t3, t2, 48 // vw
sltu t4, t1, t3 // dif < vw
move t8, t4 // return value
// t4 becomes the opposite condition (dif >= vw) and drives the selects below.
xori t4, t4, 1
maskeqz t6, t3, t4 // (dif >= vw) ? vw : 0
sub.d t6, t1, t6 // dif
slli.w t5, t2, 1
sub.w t5, t0, t5 // r - 2v
maskeqz t7, t5, t4 // (dif >= vw) ? r - 2v : 0
add.w t5, t2, t7 // new rng: r - v when dif >= vw, else v
// renorm
// clz.w counts over 32 bits; for a value below 65536 the count is 16..31,
// so flipping bit 4 yields clz - 16, the shift that normalises rng.
clz.w t4, t5 // d
xori t4, t4, 16 // d
sll.d t6, t6, t4
sll.w t5, t5, t4
sub.w t7, a5, t4 // cnt-d
st.w t5, a0, 24 // store rng
// No refill needed while cnt >= d (unsigned compare on the old cnt).
bgeu a5, t4, 9f
// refill
ld.d t0, a0, 0 // buf_pos
ld.d t1, a0, 8 // buf_end
addi.d t2, t0, 8
// Fewer than 8 bytes left before buf_end: take the careful path at 2.
bltu t1, t2, 2f
// Fast path: read 8 bytes at buf_pos, invert, byte-swap, and shift them
// into position below the bits still held in dif.
ld.d t3, t0, 0 // next_bits
addi.w t1, t7, -48 // shift_bits = cnt + 16 (- 64)
nor t3, t3, t3
sub.w t2, zero, t1 // 48 - cnt
revb.d t3, t3 // next_bits = bswap(next_bits)
srli.w t2, t2, 3 // num_bytes_read
srl.d t3, t3, t1 // next_bits >>= (shift_bits & 63)
b 3f
1:
// buf_pos >= buf_end: nothing left to read. t3 is negative, so after sign
// extension it is mostly ones; shifting it right by its own low 6 bits
// leaves ones in the low bits that get ORed into dif. buf_pos and cnt are
// not advanced on this path.
addi.w t3, t7, -48
srl.d t3, t3, t3 // pad with ones
b 4f
2:
bgeu t0, t1, 1b
// Between 1 and 7 bytes remain. Load the 8 bytes that end at buf_end (this
// starts before buf_pos) and shift out the bytes below buf_pos.
ld.d t3, t1, -8 // next_bits
sub.w t2, t2, t1 // 8 - num_bytes_left
sub.w t1, t1, t0 // num_bytes_left
slli.w t2, t2, 3
srl.d t3, t3, t2
addi.w t2, t7, -48
nor t3, t3, t3
sub.w t4, zero, t2
revb.d t3, t3
srli.w t4, t4, 3 // bytes the window could take
srl.d t3, t3, t2
// num_bytes_read = min(num_bytes_left, bytes the window could take)
sltu t2, t1, t4
maskeqz t1, t1, t2
masknez t2, t4, t2
or t2, t2, t1 // num_bytes_read
3:
slli.w t1, t2, 3
add.d t0, t0, t2
add.w t7, t7, t1 // cnt += num_bits_read
st.d t0, a0, 0 // store advanced buf_pos
4:
or t6, t6, t3 // dif |= next_bits
9:
st.w t7, a0, 28 // store cnt
st.d t6, a0, 16 // store dif
move a0, t8
endfunc
// msac_decode_bool_adapt_lsx
//
// Decodes one boolean whose probability is read from cdf[0], optionally
// adapts cdf[0] and the counter in cdf[1], then renormalises and refills
// the bit window.
//
// In:  a0 = decoder context, accessed at byte offsets
//             0 buf_pos, 8 buf_end, 16 dif (64-bit), 24 rng, 28 cnt,
//             32 allow_update_cdf
//      a1 = cdf, two 16-bit entries: [0] probability, [1] adaptation count
// Out: a0 = 1 when dif < (v << 48), otherwise 0
// Clobbers: a3, a4, a5, a7, t0-t8.
function msac_decode_bool_adapt_lsx
ld.hu a3, a1, 0 // cdf[0] /f
ld.w t0, a0, 24 // rng
ld.d t1, a0, 16 // dif
srli.w t2, t0, 8 // r >> 8
srli.w a7, a3, 6 // f >> 6
mul.w t2, t2, a7
ld.w a4, a0, 32 // allow_update_cdf
ld.w a5, a0, 28 // cnt
// v = (((r >> 8) * (f >> 6)) >> 1) + 4
srli.w t2, t2, 1
addi.w t2, t2, 4 // v
slli.d t3, t2, 48 // vw
sltu t4, t1, t3 // dif < vw
move t8, t4 // bit (also the return value)
// t4 becomes the opposite condition (dif >= vw) and drives the selects below.
xori t4, t4, 1
maskeqz t6, t3, t4 // (dif >= vw) ? vw : 0
sub.d t6, t1, t6 // dif
slli.w t5, t2, 1
sub.w t5, t0, t5 // r - 2v
maskeqz t7, t5, t4 // (dif >= vw) ? r - 2v : 0
add.w t5, t2, t7 // new rng: r - v when dif >= vw, else v
// Skip adaptation when allow_update_cdf is 0.
beqz a4, .renorm
// update_cdf
// Branch-free form of:
//   bit == 0: cdf[0] -= cdf[0] >> rate
//   bit == 1: cdf[0] += (32768 - cdf[0]) >> rate
// The extra "- bit" together with the arithmetic shift of a negative value
// makes the bit == 1 case come out as the second expression.
ld.hu t0, a1, 2 // cdf[1]
srli.w t1, t0, 4
addi.w t1, t1, 4 // rate
sltui t2, t0, 32 // count < 32
add.w t0, t0, t2 // count + (count < 32)
sub.w a3, a3, t8 // cdf[0] -= bit
slli.w t4, t8, 15 // bit * 32768
sub.w t7, a3, t4 // cdf[0] - bit - 32768 * bit
sra.w t7, t7, t1 // arithmetic shift right by rate
sub.w t7, a3, t7 // cdf[0]
st.h t7, a1, 0
st.h t0, a1, 2
.renorm:
// clz.w counts over 32 bits; for a value below 65536 the count is 16..31,
// so flipping bit 4 yields clz - 16, the shift that normalises rng.
clz.w t4, t5 // d
xori t4, t4, 16 // d
sll.d t6, t6, t4
sll.w t5, t5, t4
sub.w t7, a5, t4 // cnt-d
st.w t5, a0, 24 // store rng
// No refill needed while cnt >= d (unsigned compare on the old cnt).
bgeu a5, t4, 9f
// refill
ld.d t0, a0, 0 // buf_pos
ld.d t1, a0, 8 // buf_end
addi.d t2, t0, 8
// Fewer than 8 bytes left before buf_end: take the careful path at 2.
bltu t1, t2, 2f
// Fast path: read 8 bytes at buf_pos, invert, byte-swap, and shift them
// into position below the bits still held in dif.
ld.d t3, t0, 0 // next_bits
addi.w t1, t7, -48 // shift_bits = cnt + 16 (- 64)
nor t3, t3, t3
sub.w t2, zero, t1 // 48 - cnt
revb.d t3, t3 // next_bits = bswap(next_bits)
srli.w t2, t2, 3 // num_bytes_read
srl.d t3, t3, t1 // next_bits >>= (shift_bits & 63)
b 3f
1:
// buf_pos >= buf_end: nothing left to read. t3 is negative, so after sign
// extension it is mostly ones; shifting it right by its own low 6 bits
// leaves ones in the low bits that get ORed into dif. buf_pos and cnt are
// not advanced on this path.
addi.w t3, t7, -48
srl.d t3, t3, t3 // pad with ones
b 4f
2:
bgeu t0, t1, 1b
// Between 1 and 7 bytes remain. Load the 8 bytes that end at buf_end (this
// starts before buf_pos) and shift out the bytes below buf_pos.
ld.d t3, t1, -8 // next_bits
sub.w t2, t2, t1 // 8 - num_bytes_left
sub.w t1, t1, t0 // num_bytes_left
slli.w t2, t2, 3
srl.d t3, t3, t2
addi.w t2, t7, -48
nor t3, t3, t3
sub.w t4, zero, t2
revb.d t3, t3
srli.w t4, t4, 3 // bytes the window could take
srl.d t3, t3, t2
// num_bytes_read = min(num_bytes_left, bytes the window could take)
sltu t2, t1, t4
maskeqz t1, t1, t2
masknez t2, t4, t2
or t2, t2, t1 // num_bytes_read
3:
slli.w t1, t2, 3
add.d t0, t0, t2
add.w t7, t7, t1 // cnt += num_bits_read
st.d t0, a0, 0 // store advanced buf_pos
4:
or t6, t6, t3 // dif |= next_bits
9:
st.w t7, a0, 28 // store cnt
st.d t6, a0, 16 // store dif
move a0, t8
endfunc