Skip to content

Replace several intrinsics with Julia equivalents #22202

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 8, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions base/fastmath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ module FastMath

export @fastmath

import Core.Intrinsics: sqrt_llvm_fast, neg_float_fast,
import Core.Intrinsics: sqrt_llvm, neg_float_fast,
add_float_fast, sub_float_fast, mul_float_fast, div_float_fast, rem_float_fast,
eq_float_fast, ne_float_fast, lt_float_fast, le_float_fast

Expand Down Expand Up @@ -264,9 +264,7 @@ end
pow_fast(x::Float32, y::Integer) = ccall("llvm.powi.f32", llvmcall, Float32, (Float32, Int32), x, y)
pow_fast(x::Float64, y::Integer) = ccall("llvm.powi.f64", llvmcall, Float64, (Float64, Int32), x, y)

# TODO: Change sqrt_llvm intrinsic to avoid nan checking; add nan
# checking to sqrt in math.jl; remove sqrt_llvm_fast intrinsic
sqrt_fast(x::FloatTypes) = sqrt_llvm_fast(x)
sqrt_fast(x::FloatTypes) = sqrt_llvm(x)

# libm

Expand Down
8 changes: 0 additions & 8 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -449,10 +449,6 @@ add_tfunc(uitofp, 2, 2, bitcast_tfunc, 1)
add_tfunc(sitofp, 2, 2, bitcast_tfunc, 1)
add_tfunc(fptrunc, 2, 2, bitcast_tfunc, 1)
add_tfunc(fpext, 2, 2, bitcast_tfunc, 1)
## checked conversion ##
add_tfunc(checked_trunc_sint, 2, 2, bitcast_tfunc, 3)
add_tfunc(checked_trunc_uint, 2, 2, bitcast_tfunc, 3)
add_tfunc(check_top_bit, 1, 1, math_tfunc, 2)
## arithmetic ##
add_tfunc(neg_int, 1, 1, math_tfunc, 1)
add_tfunc(add_int, 2, 2, math_tfunc, 1)
Expand Down Expand Up @@ -502,7 +498,6 @@ add_tfunc(floor_llvm, 1, 1, math_tfunc, 10)
add_tfunc(trunc_llvm, 1, 1, math_tfunc, 10)
add_tfunc(rint_llvm, 1, 1, math_tfunc, 10)
add_tfunc(sqrt_llvm, 1, 1, math_tfunc, 20)
add_tfunc(sqrt_llvm_fast, 1, 1, math_tfunc, 20)
## same-type comparisons ##
cmp_tfunc(x::ANY, y::ANY) = Bool
add_tfunc(eq_int, 2, 2, cmp_tfunc, 1)
Expand Down Expand Up @@ -3747,13 +3742,10 @@ function is_pure_intrinsic(f::IntrinsicFunction)
return !(f === Intrinsics.pointerref || # this one is volatile
f === Intrinsics.pointerset || # this one is never effect-free
f === Intrinsics.llvmcall || # this one is never effect-free
f === Intrinsics.checked_trunc_sint ||
f === Intrinsics.checked_trunc_uint ||
f === Intrinsics.checked_sdiv_int ||
f === Intrinsics.checked_udiv_int ||
f === Intrinsics.checked_srem_int ||
f === Intrinsics.checked_urem_int ||
f === Intrinsics.check_top_bit ||
f === Intrinsics.sqrt_llvm ||
f === Intrinsics.cglobal) # cglobal throws an error for symbol-not-found
end
Expand Down
26 changes: 26 additions & 0 deletions base/int.jl
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,34 @@ trailing_ones(x::Integer) = trailing_zeros(~x)
>>>(x::BitInteger, y::Int) =
select_value(0 <= y, x >>> unsigned(y), x << unsigned(-y))

function is_top_bit_set(x::BitInteger)
@_inline_meta
lshr_int(x, (sizeof(x) << 0x03) - 1) == rem(0x01, typeof(x))
end
function check_top_bit(x::BitInteger)
@_inline_meta
is_top_bit_set(x) && throw(InexactError())
x
end

## integer conversions ##

function checked_trunc_sint{To,From}(::Type{To}, x::From)
@_inline_meta
y = trunc_int(To, x)
back = sext_int(From, y)
x == back || throw(InexactError())
y
end

function checked_trunc_uint{To,From}(::Type{To}, x::From)
@_inline_meta
y = trunc_int(To, x)
back = zext_int(From, y)
x == back || throw(InexactError())
y
end

for to in BitInteger_types, from in (BitInteger_types..., Bool)
if !(to === from)
if to.size < from.size
Expand Down
6 changes: 4 additions & 2 deletions base/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -431,8 +431,10 @@ Compute sine and cosine of `x`, where `x` is in radians.
return res
end

sqrt(x::Float64) = sqrt_llvm(x)
sqrt(x::Float32) = sqrt_llvm(x)
@inline function sqrt(x::Union{Float32,Float64})
x < zero(x) && throw(DomainError())
sqrt_llvm(x)
end

"""
sqrt(x)
Expand Down
7 changes: 0 additions & 7 deletions src/cgutils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -913,13 +913,6 @@ static void raise_exception_unless(jl_codectx_t &ctx, Value *cond, Value *exc)
raise_exception(ctx, exc, passBB);
}

// DO NOT PASS IN A CONST CONDITION!
static void raise_exception_if(jl_codectx_t &ctx, Value *cond, Value *exc)
{
raise_exception_unless(ctx, ctx.builder.CreateXor(cond, ConstantInt::get(T_int1,-1)),
exc);
}

static size_t dereferenceable_size(jl_value_t *jt) {
size_t size = 0;
if (jl_is_array_type(jt)) {
Expand Down
2 changes: 0 additions & 2 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6230,9 +6230,7 @@ static void init_julia_llvm_env(Module *m)
global_jlvalue_to_llvm("jl_emptytuple", &jl_emptytuple, m);
global_jlvalue_to_llvm("jl_diverror_exception", &jl_diverror_exception, m);
global_jlvalue_to_llvm("jl_undefref_exception", &jl_undefref_exception, m);
global_jlvalue_to_llvm("jl_domain_exception", &jl_domain_exception, m);
global_jlvalue_to_llvm("jl_overflow_exception", &jl_overflow_exception, m);
global_jlvalue_to_llvm("jl_inexact_exception", &jl_inexact_exception, m);

jlRTLD_DEFAULT_var =
new GlobalVariable(*m, T_pint8,
Expand Down
2 changes: 0 additions & 2 deletions src/init.c
Original file line number Diff line number Diff line change
Expand Up @@ -785,9 +785,7 @@ void jl_get_builtin_hooks(void)
jl_errorexception_type = (jl_datatype_t*)core("ErrorException");
jl_stackovf_exception = jl_new_struct_uninit((jl_datatype_t*)core("StackOverflowError"));
jl_diverror_exception = jl_new_struct_uninit((jl_datatype_t*)core("DivideError"));
jl_domain_exception = jl_new_struct_uninit((jl_datatype_t*)core("DomainError"));
jl_overflow_exception = jl_new_struct_uninit((jl_datatype_t*)core("OverflowError"));
jl_inexact_exception = jl_new_struct_uninit((jl_datatype_t*)core("InexactError"));
jl_undefref_exception = jl_new_struct_uninit((jl_datatype_t*)core("UndefRefError"));
jl_undefvarerror_type = (jl_datatype_t*)core("UndefVarError");
jl_interrupt_exception = jl_new_struct_uninit((jl_datatype_t*)core("InterruptException"));
Expand Down
39 changes: 1 addition & 38 deletions src/intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ static void jl_init_intrinsic_functions_codegen(Module *m)
float_func[trunc_llvm] = true;
float_func[rint_llvm] = true;
float_func[sqrt_llvm] = true;
float_func[sqrt_llvm_fast] = true;
}

extern "C"
Expand Down Expand Up @@ -499,24 +498,6 @@ static Value *generic_trunc(jl_codectx_t &ctx, Type *to, Value *x)
return ctx.builder.CreateTrunc(x, to);
}

static Value *generic_trunc_uchecked(jl_codectx_t &ctx, Type *to, Value *x)
{
Value *ans = ctx.builder.CreateTrunc(x, to);
Value *back = ctx.builder.CreateZExt(ans, x->getType());
raise_exception_unless(ctx, ctx.builder.CreateICmpEQ(back, x),
literal_pointer_val(ctx, jl_inexact_exception));
return ans;
}

static Value *generic_trunc_schecked(jl_codectx_t &ctx, Type *to, Value *x)
{
Value *ans = ctx.builder.CreateTrunc(x, to);
Value *back = ctx.builder.CreateSExt(ans, x->getType());
raise_exception_unless(ctx, ctx.builder.CreateICmpEQ(back, x),
literal_pointer_val(ctx, jl_inexact_exception));
return ans;
}

static Value *generic_sext(jl_codectx_t &ctx, Type *to, Value *x)
{
return ctx.builder.CreateSExt(x, to);
Expand Down Expand Up @@ -778,10 +759,6 @@ static jl_cgval_t emit_intrinsic(jl_codectx_t &ctx, intrinsic f, jl_value_t **ar
return generic_bitcast(ctx, argv);
case trunc_int:
return generic_cast(ctx, f, generic_trunc, argv, true, true);
case checked_trunc_uint:
return generic_cast(ctx, f, generic_trunc_uchecked, argv, true, true);
case checked_trunc_sint:
return generic_cast(ctx, f, generic_trunc_schecked, argv, true, true);
case sext_int:
return generic_cast(ctx, f, generic_sext, argv, true, true);
case zext_int:
Expand Down Expand Up @@ -1007,15 +984,6 @@ static Value *emit_untyped_intrinsic(jl_codectx_t &ctx, intrinsic f, Value **arg
literal_pointer_val(ctx, jl_diverror_exception));
return ctx.builder.CreateURem(x, y);

case check_top_bit:
// raise InexactError if argument's top bit is set
raise_exception_if(ctx,
ctx.builder.CreateTrunc(
ctx.builder.CreateLShr(x, ConstantInt::get(t, t->getPrimitiveSizeInBits() - 1)),
T_int1),
literal_pointer_val(ctx, jl_inexact_exception));
return x;

case eq_int: *newtyp = jl_bool_type; return ctx.builder.CreateICmpEQ(x, y);
case ne_int: *newtyp = jl_bool_type; return ctx.builder.CreateICmpNE(x, y);
case slt_int: *newtyp = jl_bool_type; return ctx.builder.CreateICmpSLT(x, y);
Expand Down Expand Up @@ -1150,12 +1118,7 @@ static Value *emit_untyped_intrinsic(jl_codectx_t &ctx, intrinsic f, Value **arg
Value *rintintr = Intrinsic::getDeclaration(jl_Module, Intrinsic::rint, makeArrayRef(t));
return ctx.builder.CreateCall(rintintr, x);
}
case sqrt_llvm:
raise_exception_unless(ctx,
ctx.builder.CreateFCmpUGE(x, ConstantFP::get(t, 0.0)),
literal_pointer_val(ctx, jl_domain_exception));
// fall-through
case sqrt_llvm_fast: {
case sqrt_llvm: {
Value *sqrtintr = Intrinsic::getDeclaration(jl_Module, Intrinsic::sqrt, makeArrayRef(t));
return ctx.builder.CreateCall(sqrtintr, x);
}
Expand Down
5 changes: 0 additions & 5 deletions src/intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,6 @@
ADD_I(sitofp, 2) \
ADD_I(fptrunc, 2) \
ADD_I(fpext, 2) \
/* checked conversion */ \
ADD_I(checked_trunc_sint, 2) \
ADD_I(checked_trunc_uint, 2) \
ADD_I(check_top_bit, 1) \
/* checked arithmetic */ \
ADD_I(checked_sadd_int, 2) \
ADD_I(checked_uadd_int, 2) \
Expand All @@ -91,7 +87,6 @@
ADD_I(trunc_llvm, 1) \
ADD_I(rint_llvm, 1) \
ADD_I(sqrt_llvm, 1) \
ALIAS(sqrt_llvm_fast, sqrt_llvm) \
/* pointer access */ \
ADD_I(pointerref, 3) \
ADD_I(pointerset, 4) \
Expand Down
1 change: 0 additions & 1 deletion src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ jl_value_t *jl_segv_exception;
JL_DLLEXPORT jl_value_t *jl_diverror_exception;
JL_DLLEXPORT jl_value_t *jl_domain_exception;
JL_DLLEXPORT jl_value_t *jl_overflow_exception;
JL_DLLEXPORT jl_value_t *jl_inexact_exception;
JL_DLLEXPORT jl_value_t *jl_undefref_exception;
jl_value_t *jl_interrupt_exception;
jl_datatype_t *jl_boundserror_type;
Expand Down
2 changes: 0 additions & 2 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -529,9 +529,7 @@ extern JL_DLLEXPORT jl_value_t *jl_stackovf_exception;
extern JL_DLLEXPORT jl_value_t *jl_memory_exception;
extern JL_DLLEXPORT jl_value_t *jl_readonlymemory_exception;
extern JL_DLLEXPORT jl_value_t *jl_diverror_exception;
extern JL_DLLEXPORT jl_value_t *jl_domain_exception;
extern JL_DLLEXPORT jl_value_t *jl_overflow_exception;
extern JL_DLLEXPORT jl_value_t *jl_inexact_exception;
extern JL_DLLEXPORT jl_value_t *jl_undefref_exception;
extern JL_DLLEXPORT jl_value_t *jl_interrupt_exception;
extern JL_DLLEXPORT jl_datatype_t *jl_boundserror_type;
Expand Down
4 changes: 0 additions & 4 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -807,10 +807,6 @@ JL_DLLEXPORT jl_value_t *jl_fptosi(jl_value_t *ty, jl_value_t *a);
JL_DLLEXPORT jl_value_t *jl_fptrunc(jl_value_t *ty, jl_value_t *a);
JL_DLLEXPORT jl_value_t *jl_fpext(jl_value_t *ty, jl_value_t *a);

JL_DLLEXPORT jl_value_t *jl_checked_trunc_sint(jl_value_t *ty, jl_value_t *a);
JL_DLLEXPORT jl_value_t *jl_checked_trunc_uint(jl_value_t *ty, jl_value_t *a);

JL_DLLEXPORT jl_value_t *jl_check_top_bit(jl_value_t *a);
JL_DLLEXPORT jl_value_t *jl_checked_sadd_int(jl_value_t *a, jl_value_t *b);
JL_DLLEXPORT jl_value_t *jl_checked_uadd_int(jl_value_t *a, jl_value_t *b);
JL_DLLEXPORT jl_value_t *jl_checked_ssub_int(jl_value_t *a, jl_value_t *b);
Expand Down
32 changes: 3 additions & 29 deletions src/runtime_intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -378,15 +378,13 @@ static inline jl_value_t *jl_intrinsiclambda_u1(jl_value_t *ty, void *pa, unsign

typedef void (*intrinsic_cvt_t)(unsigned, void*, unsigned, void*);
typedef unsigned (*intrinsic_cvt_check_t)(unsigned, unsigned, void*);
#define cvt_iintrinsic_checked(LLVMOP, check_op, name) \
#define cvt_iintrinsic(LLVMOP, name) \
JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *ty, jl_value_t *a) \
{ \
return jl_intrinsic_cvt(ty, a, #name, LLVMOP, check_op); \
return jl_intrinsic_cvt(ty, a, #name, LLVMOP); \
}
#define cvt_iintrinsic(LLVMOP, name) \
cvt_iintrinsic_checked(LLVMOP, NULL, name) \

static inline jl_value_t *jl_intrinsic_cvt(jl_value_t *ty, jl_value_t *a, const char *name, intrinsic_cvt_t op, intrinsic_cvt_check_t check_op)
static inline jl_value_t *jl_intrinsic_cvt(jl_value_t *ty, jl_value_t *a, const char *name, intrinsic_cvt_t op)
{
jl_ptls_t ptls = jl_get_ptls_states();
jl_value_t *aty = jl_typeof(a);
Expand All @@ -397,8 +395,6 @@ static inline jl_value_t *jl_intrinsic_cvt(jl_value_t *ty, jl_value_t *a, const
void *pa = jl_data_ptr(a);
unsigned isize = jl_datatype_size(aty);
unsigned osize = jl_datatype_size(ty);
if (check_op && check_op(isize, osize, pa))
jl_throw(jl_inexact_exception);
jl_value_t *newv = jl_gc_alloc(ptls, jl_datatype_size(ty), ty);
op(aty == (jl_value_t*)jl_bool_type ? 1 : isize * host_char_bit, pa,
osize * host_char_bit, jl_data_ptr(newv));
Expand Down Expand Up @@ -856,26 +852,6 @@ static inline int all_eq(char *p, char n, char v)
return 0;
return 1;
}
static unsigned check_trunc_sint(unsigned isize, unsigned osize, void *pa)
{
return !all_eq((char*)pa + osize, isize - osize, signbitbyte(pa, isize)); // TODO: assumes little-endian
}
cvt_iintrinsic_checked(LLVMTrunc, check_trunc_sint, checked_trunc_sint)
static unsigned check_trunc_uint(unsigned isize, unsigned osize, void *pa)
{
return !all_eq((char*)pa + osize, isize - osize, 0); // TODO: assumes little-endian
}
cvt_iintrinsic_checked(LLVMTrunc, check_trunc_uint, checked_trunc_uint)

JL_DLLEXPORT jl_value_t *jl_check_top_bit(jl_value_t *a)
{
jl_value_t *ty = jl_typeof(a);
if (!jl_is_primitivetype(ty))
jl_error("check_top_bit: value is not a primitive type");
if (signbitbyte(jl_data_ptr(a), jl_datatype_size(ty)))
jl_throw(jl_inexact_exception);
return a;
}

// checked arithmetic
#define check_sadd_int(a,b) \
Expand Down Expand Up @@ -911,8 +887,6 @@ bi_iintrinsic_fast(jl_LLVMFlipSign, flipsign, flipsign_int, )
#define trunc_float(pr, a) *pr = fp_select(a, trunc)
#define rint_float(pr, a) *pr = fp_select(a, rint)
#define sqrt_float(pr, a) \
if (a < 0) \
jl_throw(jl_domain_exception); \
*pr = fp_select(a, sqrt)
#define copysign_float(a, b) \
fp_select2(a, b, copysign)
Expand Down