# einsumfuncs.py
import numpy as np
from numpy.compat import basestring

from .core import blockwise, asarray, einsum_lookup
from ..utils import derived_from

# Labels available for einsum indices. The position of a letter in this
# string is the integer that selects it in the list-style calling convention.
einsum_symbols = (
    "abcdefghijklmnopqrstuvwxyz"
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
)
# The same labels as a set, for membership tests and set arithmetic
einsum_symbols_set = set(einsum_symbols)


def chunk_einsum(*operands, **kwargs):
    """
    Run the einsum kernel on one group of chunks.

    The keywords ``subscripts``, ``ncontract_inds`` and ``kernel_dtype`` are
    consumed here; every other keyword is forwarded to the kernel, which is
    looked up in ``einsum_lookup`` from the type of the first operand.

    Returns the kernel result reshaped to carry ``ncontract_inds`` extra
    trailing axes of length 1.
    """
    spec = kwargs.pop("subscripts")
    n_extra_axes = kwargs.pop("ncontract_inds")
    kernel_dtype = kwargs.pop("kernel_dtype")

    kernel = einsum_lookup.dispatch(type(operands[0]))
    block = kernel(spec, *operands, dtype=kernel_dtype, **kwargs)

    # One length-1 axis per contracted index, so that blockwise does not
    # need concatenate=True
    padded_shape = block.shape + (1,) * n_extra_axes
    return block.reshape(padded_shape)


# The parsing code below (parse_einsum_input and its private helpers)
# duplicates numpy's _parse_einsum_input() function
# See https://github.com/numpy/numpy/blob/master/LICENSE.txt
# or NUMPY_LICENSE.txt within this directory
def _sublist_to_subscripts(sublist):
    """Translate one list-form subscript list into its string form."""
    import operator

    pieces = []
    for s in sublist:
        if s is Ellipsis:
            pieces.append("...")
            continue
        # Accept anything usable as an index (builtin ints, numpy integers,
        # ...) rather than only instances of ``int``
        try:
            position = operator.index(s)
        except TypeError:
            raise TypeError(
                "For this input type lists must contain "
                "either int or Ellipsis"
            )
        pieces.append(einsum_symbols[position])
    return "".join(pieces)


def _implicit_output_subscript(input_subscripts):
    """Return, in sorted order, the labels that occur exactly once in the inputs."""
    labels = input_subscripts.replace(",", "")
    output_subscript = ""
    for s in sorted(set(labels)):
        if s not in einsum_symbols_set:
            raise ValueError("Character %s is not a valid symbol." % s)
        if labels.count(s) == 1:
            output_subscript += s
    return output_subscript


def parse_einsum_input(operands):
    """
    Normalize the arguments of an einsum call into explicit subscripts.

    A reproduction of numpy's _parse_einsum_input()
    which in itself is a reproduction of
    c side einsum parsing in python.

    Parameters
    ----------
    operands : sequence
        Either ``(subscripts, op0, op1, ...)`` with ``subscripts`` a string,
        or the interleaved list form
        ``(op0, sublist0, op1, sublist1, ..., [output_sublist])`` where each
        sublist contains integers and/or ``Ellipsis``.

    Returns
    -------
    input_strings : str
        Comma-separated input subscripts with every ellipsis replaced by
        explicit labels
    output_string : str
        Output subscripts; derived from the inputs when none were given
    operands : list of array_like
        The operands, each passed through ``asarray``

    Raises
    ------
    ValueError
        If no operands are given, a subscript character is invalid, "->" is
        malformed, an ellipsis is malformed or does not fit its operand, an
        output label is missing from the inputs, or the number of subscript
        terms differs from the number of operands.
    TypeError
        If a list-form subscript is neither an integer nor ``Ellipsis``.

    Examples
    --------
    The operand list is simplified to reduce printing:

    >> a = np.random.rand(4, 4)
    >> b = np.random.rand(4, 4, 4)
    >> parse_einsum_input(('...a,...a->...', a, b))
    ('za,yza', 'yz', [a, b])
    >> parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
    ('za,yza', 'yz', [a, b])
    """

    if len(operands) == 0:
        raise ValueError("No input operands")

    if isinstance(operands[0], basestring):
        subscripts = operands[0].replace(" ", "")
        operands = [asarray(o) for o in operands[1:]]

        # Ensure all characters are valid
        for s in subscripts:
            if s in ".,->":
                continue
            if s not in einsum_symbols_set:
                raise ValueError("Character %s is not a valid symbol." % s)

    else:
        # Interleaved form: operand, sublist, operand, sublist, ...
        # A trailing unpaired item, if any, is the output sublist.
        flat = list(operands)
        n_pairs = len(flat) // 2
        operand_list = flat[0 : 2 * n_pairs : 2]
        subscript_list = flat[1 : 2 * n_pairs : 2]
        output_list = flat[-1] if len(flat) % 2 else None

        operands = [asarray(v) for v in operand_list]
        subscripts = ",".join(_sublist_to_subscripts(sub) for sub in subscript_list)

        if output_list is not None:
            subscripts += "->" + _sublist_to_subscripts(output_list)

    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        if invalid or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Parse ellipses
    if "." in subscripts:
        used = subscripts.replace(".", "").replace(",", "").replace("->", "")
        # Sorted so the labels substituted for "..." do not depend on set
        # iteration order (which varies between interpreter runs)
        ellipse_inds = "".join(sorted(einsum_symbols_set - set(used)))
        longest = 0

        if "->" in subscripts:
            input_tmp, output_sub = subscripts.split("->")
            split_subscripts = input_tmp.split(",")
            out_sub = True
        else:
            split_subscripts = subscripts.split(",")
            out_sub = False

        # Each term is matched with operands[num] below, so a count mismatch
        # has to be reported here rather than surfacing as an IndexError
        if len(split_subscripts) != len(operands):
            raise ValueError(
                "Number of einsum subscripts must be equal to the number of operands."
            )

        for num, sub in enumerate(split_subscripts):
            if "." in sub:
                if (sub.count(".") != 3) or (sub.count("...") != 1):
                    raise ValueError("Invalid Ellipses.")

                # Take into account numerical values
                if operands[num].shape == ():
                    ellipse_count = 0
                else:
                    ellipse_count = max(operands[num].ndim, 1)
                    ellipse_count -= len(sub) - 3

                if ellipse_count > longest:
                    longest = ellipse_count

                if ellipse_count < 0:
                    raise ValueError("Ellipses lengths do not match.")
                elif ellipse_count == 0:
                    split_subscripts[num] = sub.replace("...", "")
                else:
                    # Take labels from the end so that ellipses of different
                    # lengths share their trailing labels
                    rep_inds = ellipse_inds[-ellipse_count:]
                    split_subscripts[num] = sub.replace("...", rep_inds)

        subscripts = ",".join(split_subscripts)
        if longest == 0:
            out_ellipse = ""
        else:
            out_ellipse = ellipse_inds[-longest:]

        if out_sub:
            subscripts += "->" + output_sub.replace("...", out_ellipse)
        else:
            # Special care for outputless ellipses: the ellipsis labels come
            # first, followed by the remaining once-only labels
            once_only = _implicit_output_subscript(subscripts)
            normal_inds = "".join(sorted(set(once_only) - set(out_ellipse)))

            subscripts += "->" + out_ellipse + normal_inds

    # Build output string if does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts = subscripts
        output_subscript = _implicit_output_subscript(subscripts)

    # Make sure output subscripts are in the input
    for char in output_subscript:
        if char not in input_subscripts:
            raise ValueError("Output character %s did not appear in the input" % char)

    # Make sure number operands is equivalent to the number of terms
    if len(input_subscripts.split(",")) != len(operands):
        raise ValueError(
            "Number of einsum subscripts must be equal to the number of operands."
        )

    return (input_subscripts, output_subscript, operands)


@derived_from(np)
def einsum(*operands, **kwargs):
    # No docstring is written here; the ``derived_from(np)`` decorator
    # presumably supplies the documentation from ``np.einsum``
    # (NOTE(review): confirm against ``derived_from``).
    #
    # Keywords handled here: ``dtype``, ``optimize`` and ``split_every``
    # (the latter is forwarded to the final ``sum``). All other keywords are
    # passed through ``blockwise`` to ``chunk_einsum``.
    dtype = kwargs.pop("dtype", None)
    optimize = kwargs.pop("optimize", False)
    split_every = kwargs.pop("split_every", None)

    # The kernel gets the dtype exactly as the caller passed it (possibly
    # None); ``dtype`` itself may be replaced by an inferred value below
    einsum_dtype = dtype

    inputs, outputs, ops = parse_einsum_input(operands)
    subscripts = "->".join((inputs, outputs))

    # Infer the output dtype from operands
    if dtype is None:
        dtype = np.result_type(*[o.dtype for o in ops])

    if optimize is not False:
        # Avoid computation of dask arrays within np.einsum_path
        # by passing in small numpy arrays broadcasted
        # up to the right shape
        fake_ops = [np.broadcast_to(o.dtype.type(0), shape=o.shape) for o in ops]
        optimize, _ = np.einsum_path(subscripts, *fake_ops, optimize=optimize)

    inputs = [tuple(i) for i in inputs.split(",")]

    # Set of all indices
    all_inds = {a for i in inputs for a in i}

    # Which indices are contracted?
    # Sorted so that the order of the temporary axes does not depend on set
    # iteration order, which varies between interpreter runs
    contract_inds = sorted(all_inds - set(outputs))
    ncontract_inds = len(contract_inds)

    # Introduce the contracted indices into the blockwise product
    # so that we get numpy arrays, not lists
    result = blockwise(
        chunk_einsum,
        tuple(outputs) + tuple(contract_inds),
        *(a for ap in zip(ops, inputs) for a in ap),
        # blockwise parameters
        adjust_chunks={ind: 1 for ind in contract_inds},
        dtype=dtype,
        # np.einsum parameters
        subscripts=subscripts,
        kernel_dtype=einsum_dtype,
        ncontract_inds=ncontract_inds,
        optimize=optimize,
        **kwargs
    )

    # Now reduce over any extra contraction dimensions, which sit after the
    # output axes
    if ncontract_inds > 0:
        size = len(outputs)
        return result.sum(
            axis=list(range(size, size + ncontract_inds)), split_every=split_every
        )

    return result