diff options
author | Peter Mikkelsen <petermikkelsen10@gmail.com> | 2022-02-01 15:53:12 +0000 |
---|---|---|
committer | Peter Mikkelsen <petermikkelsen10@gmail.com> | 2022-02-01 15:53:12 +0000 |
commit | 68d8437658cd000256ab135526e09590af4bf6c5 (patch) | |
tree | 4660576ffad691fb8d902a7cc39455db7618db3d | |
parent | 116206c95ecccf49fcce426b0f353d84a17b3314 (diff) |
Redefine scalar extension
-rw-r--r-- | array.c | 29 | ||||
-rw-r--r-- | functions.c | 39 |
2 files changed, 44 insertions, 24 deletions
@@ -114,6 +114,25 @@ scalarextend(Array *a, Array *b, Array **aa, Array **bb) } *aa = fnSame(a); *bb = fnSame(b); + }else if(a->size == 1 && b->size == 1){ + Array *shape; + if(a->rank > b->rank) + shape = fnShape(a); + else + shape = fnShape(b); + *aa = fnReshape(shape, a); + *bb = fnReshape(shape, b); + freearray(shape); + }else if(a->size == 1 && b->size != 1){ + Array *shape = fnShape(b); + *aa = fnReshape(shape, a); + *bb = fnSame(b); + freearray(shape); + }else if(a->size != 1 && b->size == 1){ + Array *shape = fnShape(a); + *aa = fnSame(a); + *bb = fnReshape(shape, b); + freearray(shape); }else return 0; return 1; @@ -148,16 +167,6 @@ commontype(Array *a, Array *b, Array **aa, Array **bb, int forcefloat) }else if(a->type == AtypeInt && b->type == AtypeFloat){ *aa = inttofloatarray(a); *bb = fnSame(b); - }else if(a->type == AtypeArray && b->type != AtypeArray){ - *aa = fnSame(a); - *bb = allocarray(AtypeArray, 0, 1); - (*bb)->arraydata[0] = b; - incref(b); - }else if(a->type != AtypeArray && b->type == AtypeArray){ - *aa = allocarray(AtypeArray, 0, 1); - (*aa)->arraydata[0] = a; - incref(a); - *bb = fnSame(b); }else return 0; return 1; diff --git a/functions.c b/functions.c index 4b93b44..df90cf4 100644 --- a/functions.c +++ b/functions.c @@ -892,22 +892,33 @@ fnSelfReference1(Array *right) #define SCALAR_FUNCTION_2(name, forcefloat, restype, cases) \ Array *name(Array *left, Array *right){\ Array *leftarr, *rightarr;\ - if(!commontype(left, right, &leftarr, &rightarr, forcefloat)) throwerror(nil, EType);\ - if(!scalarextend(leftarr, rightarr, &left, &right)) throwerror(L"scalar extension fail", ERank);\ + int nested = left->type == AtypeArray || right->type == AtypeArray;\ + if(nested){\ + leftarr = fnSame(left);\ + rightarr = fnSame(right);\ + }else{\ + if(!commontype(left, right, &leftarr, &rightarr, forcefloat))\ + throwerror(nil, EType);\ + }\ + if(!scalarextend(leftarr, rightarr, &left, &right)) throwerror(L"Scalar extension fail", ERank);\ Array *res;\ - if(left->type != AtypeArray && restype != left->type)\ - res = duparrayshape(left, restype);\ - else\ - res = duparray(left);\ - for(int i = 0; i < left->size; i++)\ - switch(left->type){\ - default: throwerror(nil, EType); break;\ - case AtypeArray:\ - freearray(res->arraydata[i]);\ - res->arraydata[i] = name(left->arraydata[i], right->arraydata[i]);\ - break;\ - cases\ + if(nested){\ + res = duparrayshape(left, AtypeArray);\ + for(int i = 0; i < left->size; i++){\ + Array *l = arrayitem(left, i);\ + Array *r = arrayitem(right, i);\ + res->arraydata[i] = name(l,r);\ + freearray(l);\ + freearray(r);\ }\ + }else{\ + res = duparray(left);\ + for(int i = 0; i < left->size; i++)\ + switch(left->type){\ + default: throwerror(nil, EType); break;\ + cases\ + }\ + }\ freearray(leftarr); freearray(rightarr); freearray(left); freearray(right);\ return res;} |