diff options
author | Peter Mikkelsen <petermikkelsen10@gmail.com> | 2022-01-14 14:28:20 +0000 |
---|---|---|
committer | Peter Mikkelsen <petermikkelsen10@gmail.com> | 2022-01-14 14:28:20 +0000 |
commit | 9e586c6e7a29b8312936be3cf14a4f1685548589 (patch) | |
tree | 59f18c4b48005f1b7bf459261b180368f7bc7959 | |
parent | 3330e6ed47c71b87bb79c5e277214a36d3f2ad3b (diff) |
Implement ⍪ better, but the code is waaay too big and copy-pasty
-rw-r--r-- | functions.c | 129 |
1 files changed, 114 insertions, 15 deletions
diff --git a/functions.c b/functions.c index 636c195..5fa9305 100644 --- a/functions.c +++ b/functions.c @@ -361,21 +361,120 @@ fnRight(Array *left, Array *right) Array * fnCatenateFirst(Array *left, Array *right) { - /* not even close to being right, but it works for stranding :) */ - left = left->rank == 0 ? fnRavel(left) : fnSame(left); - right = right->rank == 0 ? fnRavel(right) : fnSame(right); - - /* assume two vectors of same type for now */ - Array *res = allocarray(left->type, 1, left->size+right->size); - res->shape[0] = left->shape[0] + right->shape[0]; - memcpy(res->rawdata, left->rawdata, datasizes[res->type]*left->size); - memcpy(res->rawdata+datasizes[res->type]*left->size, right->rawdata, datasizes[res->type]*right->size); - if(res->type == AtypeArray) - for(int i = 0; i < res->size; i++) - incref(res->arraydata[i]); - freearray(left); - freearray(right); - return res; + Array *leftarr; + Array *rightarr; + + if(left->rank == 0 && right->rank != 0){ + /* extend left to right->rank with first axis=1 */ + rightarr = fnSame(right); + Array *shape = fnShape(right); + shape->intdata[0] = 1; + leftarr = fnReshape(shape, left); + freearray(shape); + }else if(left->rank != 0 && right->rank == 0){ + /* extend right to left->rank with first axis=1 */ + leftarr = fnSame(left); + Array *shape = fnShape(left); + shape->intdata[0] = 1; + rightarr = fnReshape(shape, right); + freearray(shape); + }else if(left->rank == 0 && right->rank == 0){ + /* turn both scalars into vectors */ + leftarr = fnRavel(left); + rightarr = fnRavel(right); + }else{ + /* Check that the shapes match */ + if(left->rank == right->rank-1){ + /* extend left with unit dimension */ + Array *shape = allocarray(AtypeInt, 1, left->rank+1); + shape->intdata[0] = 1; + for(int i = 1; i < left->rank+1; i++) + shape->intdata[i] = left->shape[i-1]; + leftarr = fnReshape(shape, left); + rightarr = fnSame(right); + freearray(shape); + }else if(right->rank == left->rank-1){ + /* extend right with unit dimension */ + Array *shape = allocarray(AtypeInt, 1, right->rank+1); + shape->intdata[0] = 1; + for(int i = 1; i < right->rank+1; i++) + shape->intdata[i] = right->shape[i-1]; + rightarr = fnReshape(shape, right); + leftarr = fnSame(left); + freearray(shape); + }else if(right->rank == left->rank){ + leftarr = fnSame(left); + rightarr = fnSame(right); + }else{ + print("Ranks don't match\n"); + exits(nil); + return nil; + } + + for(int i = 1; i < leftarr->rank; i++) + if(leftarr->shape[i] != rightarr->shape[i]){ + print("Shapes don't match, lol\n"); + exits(nil); + } + } + + int type, rank, leftsize, rightsize; + if(leftarr->type == AtypeArray || rightarr->type == AtypeArray || leftarr->type != rightarr->type) + type = AtypeArray; + else + type = leftarr->type; + if(leftarr->rank > rightarr->rank) + rank = leftarr->rank; + else + rank = rightarr->rank; + + leftsize = leftarr->shape[0]; + rightsize = rightarr->shape[0]; + for(int i = 1; i < rank; i++) + leftsize *= leftarr->shape[i]; + for(int i = 1; i < rank; i++) + rightsize *= rightarr->shape[i]; + + Array *result = allocarray(type, rank, leftsize + rightsize); + int i, j; + result->shape[0] = leftarr->shape[0] + rightarr->shape[0]; + for(i = 1; i < result->rank; i++) + result->shape[i] = leftarr->shape[i]; + + /* TODO reduce duplicated code between copies from left and right */ + /* Copy data from the left array */ + for(i = 0, j = 0; i < leftarr->size; i++, j++){ + if(type == AtypeArray && leftarr->type == AtypeArray){ + result->arraydata[j] = leftarr->arraydata[i]; + incref(result->arraydata[j]); + }else if(type == AtypeArray && leftarr->type != AtypeArray){ + result->arraydata[j] = arrayitem(leftarr, i); + }else{ + memcpy( + result->rawdata + j * datasizes[type], + leftarr->rawdata + i * datasizes[type], + datasizes[type]); + } + } + + /* Copy data from the right array */ + for(i = 0; i < rightarr->size; i++, j++){ + if(type == AtypeArray && rightarr->type == AtypeArray){ + result->arraydata[j] = rightarr->arraydata[i]; + incref(result->arraydata[j]); + }else if(type == AtypeArray && rightarr->type != AtypeArray){ + result->arraydata[j] = arrayitem(rightarr, i); + }else{ + memcpy( + result->rawdata + j * datasizes[type], + rightarr->rawdata + i * datasizes[type], + datasizes[type]); + } + } + + freearray(leftarr); + freearray(rightarr); + return result; } Array * |