summaryrefslogtreecommitdiff
path: root/functions.c
diff options
context:
space:
mode:
Diffstat (limited to 'functions.c')
-rw-r--r--functions.c65
1 files changed, 45 insertions, 20 deletions
diff --git a/functions.c b/functions.c
index 7a961a3..f5f5fac 100644
--- a/functions.c
+++ b/functions.c
@@ -524,8 +524,6 @@ fnMix(Array *right)
for(j = 0, offset = 0; offset < commonsize; j++){
for(int k = 0; index[commonrank-1-k] == a->shape[a->rank-1-k]; k++){
int nfill = commonshape->intdata[commonrank-1-k] - a->shape[a->rank-1-k];
- if(nfill)
- print("Adding %d fills\n", nfill);
while(nfill--){
memcpy(result->rawdata + (i * commonsize + offset) * datasizes[a->type],
fill->rawdata, datasizes[a->type]);
@@ -1175,15 +1173,24 @@ SCALAR_FUNCTION_2(fnNor, 0, AtypeInt,
Array *
fnTake(Array *left, Array *right)
{
+ int i;
if(left->type != AtypeInt)
throwerror(nil, EType);
if(left->rank > 1)
throwerror(nil, ERank);
+
+ if(right->rank == 0){
+ right = duparray(right);
+ right->rank = left->size;
+ right->shape = realloc(right->shape, sizeof(int) * right->rank);
+ for(i = 0; i < right->rank; i++)
+ right->shape[i] = 1;
+ }else
+ right = fnSame(right);
+
if(left->size > right->rank)
throwerror(nil, ELength);
-
- int i;
- if(left->size == right->rank)
+ else if(left->size == right->rank)
left = fnSame(left);
else{
Array *old = left;
@@ -1192,13 +1199,6 @@ fnTake(Array *left, Array *right)
left->intdata[i] = old->intdata[i];
}
- if(right->rank == 0){
- Array *leftshape = fnShape(left);
- right = fnReshape(leftshape, right);
- freearray(leftshape);
- }else
- right = fnSame(right);
-
int *shape = malloc(sizeof(int) * left->size);
int size = 1;
for(i = 0; i < left->size; i++){
@@ -1207,28 +1207,53 @@ fnTake(Array *left, Array *right)
size *= shape[i];
}
+ Array *fill = fillelement(right);
Array *result = allocarray(right->type, right->rank, size);
for(i = 0; i < right->rank; i++)
result->shape[i] = shape[i];
int *index = mallocz(sizeof(int) * left->size, 1);
- int offset;
- for(i = 0, offset = 0; offset < size; i++){
+ int fromindex;
+ for(i = 0; i < size; i++){
for(int j = left->size-1; index[j] == shape[j]; j--){
index[j] = 0;
index[j-1]++;
}
- print("Result Index: ");
- for(int j = 0; j < left->size; j++)
- print("%d ", index[j]);
- print("\n");
- /* if index is part of left vector, select those places */
+ int inside = 1;
+ fromindex = 0;
+ for(int j = 0; j < left->size && inside; j++){
+ vlong n = left->intdata[j];
+ vlong m = index[j];
+ if(n > 0 && m >= right->shape[j])
+ inside = 0;
+ else if(n < 0 && m < ((-n)-right->shape[j]))
+ inside = 0;
+ int add;
+ if(n < 0)
+ add = n + index[j] + right->shape[j];
+ else
+ add = index[j];
+ for(int k = j+1; k < right->rank; k++)
+ add *= right->shape[k];
+ fromindex += add;
+ }
- offset++;
+ if(inside)
+ memcpy(result->rawdata + i*datasizes[result->type],
+ right->rawdata + fromindex*datasizes[result->type],
+ datasizes[result->type]);
+ else
+ memcpy(result->rawdata + i*datasizes[result->type],
+ fill->rawdata, datasizes[result->type]);
+ if(result->type == AtypeArray)
+ incref(result->arraydata[i]);
index[left->size-1]++;
}
+ free(shape);
+ free(index);
+ freearray(fill);
freearray(left);
freearray(right);
return result;