_lsh.cpp 6.43 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
#include <Python.h>
#include <numpy/arrayobject.h>
#include "lsh.h"
#include <limits.h>
#include <math.h>

/* Docstrings attached to the module and to its single exported function. */
static char module_docstring[] =
    "This module implements fast nearest-neighbor retrieval using LSH.";
static char lsh_docstring[] =
    "Calculate the closest neightbours with distances given a query window.";

/* Forward declaration of the Python-callable entry point. */
static PyObject *lsh_lsh(PyObject *self, PyObject *args);

/* Method table: exposes lsh_lsh under the name "lsh" (positional arguments
 * only); the all-NULL row is the terminating sentinel. */
static PyMethodDef module_methods[] = {
    {"lsh", lsh_lsh, METH_VARARGS, lsh_docstring},
    {NULL, NULL, 0, NULL},
};

/* Module-initialisation compatibility macros.
 *
 * They let one init function body be compiled against either major version
 * of the CPython C-API:
 *   MOD_ERROR_VAL        - value returned from the init function on failure
 *   MOD_SUCCESS_VAL(val) - value returned from the init function on success
 *   MOD_INIT(name)       - signature of the init function for module `name`
 *   MOD_DEF(ob, ...)     - statements that create the module and store it
 *                          in `ob`
 * MOD_DEF already ends with a semicolon, so the call site does not need one.
 */
#if PY_MAJOR_VERSION >= 3
/* Python 3: the init function is PyInit_<name> and returns the module
 * object, or NULL on failure. */
#define MOD_ERROR_VAL NULL
  #define MOD_SUCCESS_VAL(val) val
  #define MOD_INIT(name) PyMODINIT_FUNC PyInit_##name(void)
  /* Declares a function-local static PyModuleDef (size field -1) and creates
   * the module from it with PyModule_Create. */
  #define MOD_DEF(ob, name, doc, methods) \
          static struct PyModuleDef moduledef = { \
            PyModuleDef_HEAD_INIT, name, doc, -1, methods, }; \
          ob = PyModule_Create(&moduledef);
#else
/* Python 2: the init function is init<name> and returns void, so both
 * return-value macros expand to nothing. */
#define MOD_ERROR_VAL
#define MOD_SUCCESS_VAL(val)
#define MOD_INIT(name) void init##name(void)
/* Creates the module with Py_InitModule3 (note the methods/doc argument
 * order differs from the macro's parameter order). */
#define MOD_DEF(ob, name, doc, methods) \
          ob = Py_InitModule3(name, methods, doc);
#endif

/* Entry point run when `_lsh` is imported: creates the module object from
 * module_docstring/module_methods and loads the NumPy C-API. */
MOD_INIT(_lsh)
{
    PyObject *mod;

    /* MOD_DEF expands to complete statements; no trailing semicolon needed. */
    MOD_DEF(mod, "_lsh", module_docstring, module_methods)

    if (!mod) {
        return MOD_ERROR_VAL;
    }

    /* Load the NumPy C-API; the macro returns from this function itself if
     * the import fails. */
    import_array();

    return MOD_SUCCESS_VAL(mod);
}

/* Free an L x K table of hash-function vectors built by
 * lsh_py_alloc_hash_functions(). Accepts NULL and partially built tables. */
static void lsh_py_free_hash_functions(double ***table, int L, int K)
{
    if (table == NULL) {
        return;
    }
    for (int l = 0; l < L; l++) {
        if (table[l] == NULL) {
            continue;
        }
        for (int k = 0; k < K; k++) {
            free(table[l][k]);
        }
        free(table[l]);
    }
    free(table);
}

/* Allocate a zero-filled L x K table of vectors holding `dim` doubles each.
 * Returns NULL if any allocation fails; nothing is leaked in that case. */
static double ***lsh_py_alloc_hash_functions(int L, int K, int dim)
{
    /* Never request zero bytes: calloc(0, ...) may legitimately return NULL,
     * which would be indistinguishable from an allocation failure. */
    const size_t width = dim > 0 ? (size_t)dim : 1;
    double ***table = (double ***)calloc((size_t)L, sizeof(double **));

    if (table == NULL) {
        return NULL;
    }
    for (int l = 0; l < L; l++) {
        table[l] = (double **)calloc((size_t)K, sizeof(double *));
        if (table[l] == NULL) {
            lsh_py_free_hash_functions(table, L, K);
            return NULL;
        }
        for (int k = 0; k < K; k++) {
            table[l][k] = (double *)calloc(width, sizeof(double));
            if (table[l][k] == NULL) {
                lsh_py_free_hash_functions(table, L, K);
                return NULL;
            }
        }
    }
    return table;
}

/* Derive K and L (the first two dimensions of every returned array) from the
 * query length. The formulas and their constants are kept verbatim.
 * Returns 0 on success, -1 if either value is not a positive int. */
static int lsh_py_table_sizes(int query_size, int *K, int *L)
{
    const double p = 1 - exp(-2*(0.1)*(0.1)*query_size);
    const double k_real = ceil(log((log(0.5))/log(p))/log(p/(0.5)));

    /* Written as a negated range test so that NaN is rejected as well. */
    if (!(k_real >= 1.0 && k_real <= (double)INT_MAX)) {
        return -1;
    }
    const int k = (int)k_real;
    const double l_real = ceil(log(0.05)/(log(1-pow(p, k))));

    if (!(l_real >= 1.0 && l_real <= (double)INT_MAX)) {
        return -1;
    }
    *K = k;
    *L = (int)l_real;
    return 0;
}

/* Copy an L x K table of int rows (`width` elements each) into a new
 * contiguous NumPy array of shape (L, K, width) and dtype NPY_INT.
 * Returns a new reference, or NULL with a Python exception set. */
static PyObject *lsh_py_pack_ints(int ***table, int L, int K, int width)
{
    npy_intp shape[3] = {L, K, width};
    PyArrayObject *out = (PyArrayObject *)PyArray_SimpleNew(3, shape, NPY_INT);

    if (out == NULL) {
        return NULL;
    }
    int *dst = (int *)PyArray_DATA(out);
    for (int l = 0; l < L; l++) {
        for (int k = 0; k < K; k++) {
            if (width > 0) {
                memcpy(dst, table[l][k], (size_t)width * sizeof(int));
            }
            dst += width;
        }
    }
    return (PyObject *)out;
}

/* Same as lsh_py_pack_ints, for rows of doubles (dtype NPY_DOUBLE). */
static PyObject *lsh_py_pack_doubles(double ***table, int L, int K, int width)
{
    npy_intp shape[3] = {L, K, width};
    PyArrayObject *out = (PyArrayObject *)PyArray_SimpleNew(3, shape, NPY_DOUBLE);

    if (out == NULL) {
        return NULL;
    }
    double *dst = (double *)PyArray_DATA(out);
    for (int l = 0; l < L; l++) {
        for (int k = 0; k < K; k++) {
            if (width > 0) {
                memcpy(dst, table[l][k], (size_t)width * sizeof(double));
            }
            dst += width;
        }
    }
    return (PyObject *)out;
}

/* Python entry point: lsh(data, query, r, a, sd[, weights]).
 *
 * Arguments (positional):
 *   data    - converted to a 3-D array of doubles; dimension 0 is passed to
 *             lsh() as data_size and dimension 2 as channel_size
 *   query   - converted to a 2-D array of doubles; dimension 0 is passed to
 *             lsh() as query_size
 *   r, a, sd - doubles forwarded unchanged to lsh()
 *   weights - optional, converted to a 1-D array of doubles; omitted or None
 *             means a NULL weights pointer is passed to lsh()
 *
 * Returns a 3-tuple of new NumPy arrays:
 *   candidates     int,    shape (L, K, nrOfCandidates)
 *   distances      double, shape (L, K, nrOfCandidates)
 *   hash_functions double, shape (L, K, channel_size)
 *
 * Raises TypeError if an input cannot be converted, ValueError if the inputs
 * are too large or K/L cannot be derived from the query length, MemoryError
 * on allocation failure, and RuntimeError if lsh() reports failure.
 *
 * NOTE(review): the second dimension of `query` is not checked against
 * channel_size because lsh()'s contract is not visible here - confirm.
 */
static PyObject* lsh_lsh(PyObject* self, PyObject* args) {
    /* Borrowed references owned by the argument tuple. */
    PyObject *data_obj = NULL;
    PyObject *query_obj = NULL;
    PyObject *weights_obj = NULL;
    /* Arrays produced by PyArray_AsCArray; non-NULL only while owned here. */
    PyObject *data_arr = NULL;
    PyObject *query_arr = NULL;
    PyObject *weights_arr = NULL;
    double ***data = NULL;
    double **query = NULL;
    double *weights = NULL;
    double ***hashFunctions = NULL;
    int ***candidates = NULL;
    double ***distances = NULL;
    int nrOfCandidates = 0;
    double r = 0.0;
    double a = 0.0;
    double sd = 0.0;
    npy_intp dims1[1];
    npy_intp dims2[2];
    npy_intp dims3[3];
    int data_size = 0;
    int channel_size = 0;
    int query_size = 0;
    int K = 0;
    int L = 0;
    int status = 0;
    PyObject *py_candidates = NULL;
    PyObject *py_distances = NULL;
    PyObject *py_hash_functions = NULL;
    PyObject *ret = NULL;
    /* Every local is declared above so that `goto cleanup` never jumps over
     * an initialisation (which C++ forbids). */

    (void)self;

    printf("Parsing\n");

    /// Parse the input tuple
    if (!PyArg_ParseTuple(args, "OOddd|O", &data_obj, &query_obj, &r, &a, &sd, &weights_obj)) {
        return NULL;
    }

    printf("Loading data\n");
    /// Convert data, query and weights to C arrays.
    /// PyArray_AsCArray consumes the descriptor reference it receives, so
    /// each call gets a fresh one. On success it replaces the object pointer
    /// with a new array that must later be released with PyArray_Free; on
    /// failure the pointer is left untouched, hence the reset to NULL.
    query_arr = query_obj;
    if (PyArray_AsCArray(&query_arr, (void *)&query, dims2, 2, PyArray_DescrFromType(NPY_DOUBLE)) < 0) {
        query_arr = NULL;
        printf("ERROR\n");
        PyErr_SetString(PyExc_TypeError, "error converting to c array");
        goto cleanup;
    }
    data_arr = data_obj;
    if (PyArray_AsCArray(&data_arr, (void *)&data, dims3, 3, PyArray_DescrFromType(NPY_DOUBLE)) < 0) {
        data_arr = NULL;
        printf("ERROR\n");
        PyErr_SetString(PyExc_TypeError, "error converting to c array");
        goto cleanup;
    }
    if (weights_obj != NULL && weights_obj != Py_None)
    {
        printf("Using weights");
        weights_arr = weights_obj;
        if (PyArray_AsCArray(&weights_arr, (void *)&weights, dims1, 1, PyArray_DescrFromType(NPY_DOUBLE)) < 0) {
            weights_arr = NULL;
            weights = NULL;
            PyErr_SetString(PyExc_TypeError, "error converting weights to c array");
            goto cleanup;
        }
    }

    /// Take the dimensions from the converted arrays, whose rank is now
    /// guaranteed, instead of from the unchecked input objects.
    if (dims3[0] > INT_MAX || dims3[2] > INT_MAX || dims2[0] > INT_MAX) {
        PyErr_SetString(PyExc_ValueError, "input arrays are too large");
        goto cleanup;
    }
    data_size = (int)dims3[0];
    channel_size = (int)dims3[2];
    query_size = (int)dims2[0];

    if (lsh_py_table_sizes(query_size, &K, &L) < 0) {
        K = 0;
        L = 0;
        PyErr_SetString(PyExc_ValueError, "cannot derive LSH table sizes from the query length");
        goto cleanup;
    }
    printf("K: %d\n", K);
    printf("L: %d\n", L);
    printf("Dim: %d\n", channel_size);

    /// Buffer that lsh() fills with the hash functions it used
    hashFunctions = lsh_py_alloc_hash_functions(L, K, channel_size);
    if (hashFunctions == NULL) {
        PyErr_NoMemory();
        goto cleanup;
    }

    status = lsh(data, data_size, query_size, channel_size, query, L, K, r, a, sd, candidates, distances, hashFunctions, weights, nrOfCandidates);
    if (status) {
        PyErr_SetString(PyExc_RuntimeError, "lsh could not allocate memory");
        goto cleanup;
    }
    if (candidates == NULL || distances == NULL || nrOfCandidates < 0) {
        PyErr_SetString(PyExc_RuntimeError, "lsh returned no results");
        goto cleanup;
    }

    printf("Number of candidates: %d\n", nrOfCandidates);

    /// Copy the results into NumPy-owned memory
    py_candidates = lsh_py_pack_ints(candidates, L, K, nrOfCandidates);
    if (py_candidates == NULL) {
        goto cleanup;
    }
    py_distances = lsh_py_pack_doubles(distances, L, K, nrOfCandidates);
    if (py_distances == NULL) {
        goto cleanup;
    }
    py_hash_functions = lsh_py_pack_doubles(hashFunctions, L, K, channel_size);
    if (py_hash_functions == NULL) {
        goto cleanup;
    }

    /// "N" hands our three references over to the result tuple.
    ret = Py_BuildValue("NNN", py_candidates, py_distances, py_hash_functions);
    py_candidates = NULL;
    py_distances = NULL;
    py_hash_functions = NULL;

cleanup:
    Py_XDECREF(py_candidates);
    Py_XDECREF(py_distances);
    Py_XDECREF(py_hash_functions);
    lsh_py_free_hash_functions(hashFunctions, L, K);
    /* NOTE(review): `candidates` and `distances` are produced by lsh(); how
     * they are allocated is not visible here, so they are not released. */
    if (weights_arr != NULL) {
        PyArray_Free(weights_arr, (void *)weights);
    }
    if (data_arr != NULL) {
        PyArray_Free(data_arr, (void *)data);
    }
    if (query_arr != NULL) {
        PyArray_Free(query_arr, (void *)query);
    }
    return ret;
}