Skip to content

Commit 1b75678

Browse files
committed
Generalize findnz and reduction to AbstractSparseVectors
1 parent e081db6 commit 1b75678

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

src/sparsevector.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ SparseVector(n::Integer, nzind::Vector{Ti}, nzval::Vector{Tv}) where {Tv,Ti} =
4848
# union of such a view and a SparseVector so we define an alias for such a union as well
4949
const SparseColumnView{Tv,Ti} = SubArray{Tv,1,<:AbstractSparseMatrixCSC{Tv,Ti},Tuple{Base.Slice{Base.OneTo{Int}},Int},false}
5050
const SparseVectorView{Tv,Ti} = SubArray{Tv,1,<:AbstractSparseVector{Tv,Ti},Tuple{Base.Slice{Base.OneTo{Int}}},false}
51-
const SparseVectorUnion{Tv,Ti} = Union{SparseVector{Tv,Ti}, SparseColumnView{Tv,Ti}, SparseVectorView{Tv,Ti}}
51+
const SparseVectorUnion{Tv,Ti} = Union{AbstractSparseVector{Tv,Ti}, SparseColumnView{Tv,Ti}, SparseVectorView{Tv,Ti}}
5252
const AdjOrTransSparseVectorUnion{Tv,Ti} = LinearAlgebra.AdjOrTrans{Tv, <:SparseVectorUnion{Tv,Ti}}
5353

5454
### Basic properties
@@ -779,7 +779,7 @@ findall(p::Base.Fix2{typeof(in)}, x::SparseVector{<:Any,Ti}) where {Ti} =
779779
findnz(x::SparseVector)
780780
781781
Return a tuple `(I, V)` where `I` is the indices of the stored ("structurally non-zero")
782-
values in sparse vector `x` and `V` is a vector of the values.
782+
values in sparse vector-like `x` and `V` is a vector of the values.
783783
784784
# Examples
785785
```jldoctest
@@ -794,7 +794,7 @@ julia> findnz(x)
794794
([1, 4, 6, 8], [1, 2, 4, 3])
795795
```
796796
"""
797-
function findnz(x::SparseVector{Tv,Ti}) where {Tv,Ti}
797+
function findnz(x::SparseVectorUnion{Tv,Ti}) where {Tv,Ti}
798798
numnz = nnz(x)
799799

800800
I = Vector{Ti}(undef, numnz)
@@ -1405,10 +1405,10 @@ for (fun, mode) in [(:+, 1), (:-, 1), (:*, 0), (:min, 2), (:max, 2)]
14051405
end
14061406

14071407
### Reduction
1408-
Base.reducedim_initarray(A::AbstractSparseVector, region, v0, ::Type{R}) where {R} =
1408+
Base.reducedim_initarray(A::SparseVectorUnion, region, v0, ::Type{R}) where {R} =
14091409
fill!(Array{R}(undef, Base.to_shape(Base.reduced_indices(A, region))), v0)
14101410

1411-
function Base._mapreduce(f, op, ::IndexCartesian, A::AbstractSparseVector{T}) where {T}
1411+
function Base._mapreduce(f, op, ::IndexCartesian, A::SparseVectorUnion{T}) where {T}
14121412
isempty(A) && return Base.mapreduce_empty(f, op, T)
14131413
z = nnz(A)
14141414
rest, ini = if z == 0
@@ -1419,7 +1419,7 @@ function Base._mapreduce(f, op, ::IndexCartesian, A::AbstractSparseVector{T}) wh
14191419
_mapreducezeros(f, op, T, rest, ini)
14201420
end
14211421

1422-
function Base.mapreducedim!(f, op, R::AbstractVector, A::AbstractSparseVector)
1422+
function Base.mapreducedim!(f, op, R::AbstractVector, A::SparseVectorUnion)
14231423
# dim1 reduction could be safely replaced with a mapreduce
14241424
if length(R) == 1
14251425
I = firstindex(R)

test/sparsevector.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,10 @@ end
337337
@test findall(!iszero, xc) == findall(!iszero, fc)
338338
@test findnz(xc) == ([2, 3, 5], [1.25, 0, -0.75])
339339
end
340+
let Xc = spdiagm(spv_x1)
341+
@test all(isempty, findnz(@view Xc[:,1]))
342+
@test findnz(@view Xc[:,2]) == ([2], [1.25])
343+
end
340344
end
341345
### Array manipulation
342346

0 commit comments

Comments
 (0)