summaryrefslogtreecommitdiff
path: root/functions.c
diff options
context:
space:
mode:
Diffstat (limited to 'functions.c')
-rw-r--r--functions.c129
1 files changed, 114 insertions, 15 deletions
diff --git a/functions.c b/functions.c
index 636c195..5fa9305 100644
--- a/functions.c
+++ b/functions.c
@@ -361,21 +361,120 @@ fnRight(Array *left, Array *right)
Array *
fnCatenateFirst(Array *left, Array *right)
{
- /* not even close to being right, but it works for stranding :) */
- left = left->rank == 0 ? fnRavel(left) : fnSame(left);
- right = right->rank == 0 ? fnRavel(right) : fnSame(right);
-
- /* assume two vectors of same type for now */
- Array *res = allocarray(left->type, 1, left->size+right->size);
- res->shape[0] = left->shape[0] + right->shape[0];
- memcpy(res->rawdata, left->rawdata, datasizes[res->type]*left->size);
- memcpy(res->rawdata+datasizes[res->type]*left->size, right->rawdata, datasizes[res->type]*right->size);
- if(res->type == AtypeArray)
- for(int i = 0; i < res->size; i++)
- incref(res->arraydata[i]);
- freearray(left);
- freearray(right);
- return res;
+ Array *leftarr;
+ Array *rightarr;
+
+ if(left->rank == 0 && right->rank != 0){
+ /* extend left to right->rank with first axis=1 */
+ rightarr = fnSame(right);
+ Array *shape = fnShape(right);
+ shape->intdata[0] = 1;
+ leftarr = fnReshape(shape, left);
+ freearray(shape);
+ }else if(left->rank != 0 && right->rank == 0){
+ /* extend right to left->rank with first axis=1 */
+ leftarr = fnSame(left);
+ Array *shape = fnShape(left);
+ shape->intdata[0] = 1;
+ rightarr = fnReshape(shape, right);
+ freearray(shape);
+ }else if(left->rank == 0 && right->rank == 0){
+ /* turn both scalars into vectors */
+ leftarr = fnRavel(left);
+ rightarr = fnRavel(right);
+ }else{
+ /* Check that the shapes match */
+ if(left->rank == right->rank-1){
+ /* extend left with unit dimension */
+ Array *shape = allocarray(AtypeInt, 1, left->rank+1);
+ shape->intdata[0] = 1;
+ for(int i = 1; i < left->rank+1; i++)
+ shape->intdata[i] = left->shape[i-1];
+ leftarr = fnReshape(shape, left);
+ rightarr = fnSame(right);
+ freearray(shape);
+ }else if(right->rank == left->rank-1){
+ /* extend right with unit dimension */
+ Array *shape = allocarray(AtypeInt, 1, right->rank+1);
+ shape->intdata[0] = 1;
+ for(int i = 1; i < right->rank+1; i++)
+ shape->intdata[i] = right->shape[i-1];
+ rightarr = fnReshape(shape, right);
+ leftarr = fnSame(left);
+ freearray(shape);
+ }else if(right->rank == left->rank){
+ leftarr = fnSame(left);
+ rightarr = fnSame(right);
+ }else{
+ print("Ranks don't match\n");
+ exits(nil);
+ return nil;
+ }
+
+ for(int i = 1; i < leftarr->rank; i++)
+ if(leftarr->shape[i] != rightarr->shape[i]){
+ print("Shapes don't match, lol\n");
+ exits(nil);
+ }
+ }
+
+ int type, rank, leftsize, rightsize;
+ if(leftarr->type == AtypeArray || rightarr->type == AtypeArray || leftarr->type != rightarr->type)
+ type = AtypeArray;
+ else
+ type = leftarr->type;
+ if(leftarr->rank > rightarr->rank)
+ rank = leftarr->rank;
+ else
+ rank = rightarr->rank;
+
+ leftsize = leftarr->shape[0];
+ rightsize = rightarr->shape[0];
+ for(int i = 1; i < rank; i++)
+ leftsize *= leftarr->shape[i];
+ for(int i = 1; i < rank; i++)
+ rightsize *= rightarr->shape[i];
+
+ Array *result = allocarray(type, rank, leftsize + rightsize);
+ int i, j;
+ result->shape[0] = leftarr->shape[0] + rightarr->shape[0];
+ for(i = 1; i < result->rank; i++)
+ result->shape[i] = leftarr->shape[i];
+
+ /* TODO reduce duplicated code between copies from left and right */
+ /* Copy data from the left array */
+ for(i = 0, j = 0; i < leftarr->size; i++, j++){
+ if(type == AtypeArray && leftarr->type == AtypeArray){
+ result->arraydata[j] = leftarr->arraydata[i];
+ incref(result->arraydata[j]);
+ }else if(type == AtypeArray && leftarr->type != AtypeArray){
+ result->arraydata[j] = arrayitem(leftarr, i);
+ }else{
+ memcpy(
+ result->rawdata + j * datasizes[type],
+ leftarr->rawdata + i * datasizes[type],
+ datasizes[type]);
+ }
+ }
+
+ /* Copy data from the right array */
+ for(i = 0; i < rightarr->size; i++, j++){
+ if(type == AtypeArray && rightarr->type == AtypeArray){
+ result->arraydata[j] = rightarr->arraydata[i];
+ incref(result->arraydata[j]);
+ }else if(type == AtypeArray && rightarr->type != AtypeArray){
+ result->arraydata[j] = arrayitem(rightarr, i);
+ }else{
+ memcpy(
+ result->rawdata + j * datasizes[type],
+ rightarr->rawdata + i * datasizes[type],
+ datasizes[type]);
+ }
+ }
+
+ freearray(leftarr);
+ freearray(rightarr);
+ return result;
}
Array *