summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--array.c29
-rw-r--r--functions.c39
2 files changed, 44 insertions, 24 deletions
diff --git a/array.c b/array.c
index e71d938..c29c411 100644
--- a/array.c
+++ b/array.c
@@ -114,6 +114,25 @@ scalarextend(Array *a, Array *b, Array **aa, Array **bb)
}
*aa = fnSame(a);
*bb = fnSame(b);
+ }else if(a->size == 1 && b->size == 1){
+ Array *shape;
+ if(a->rank > b->rank)
+ shape = fnShape(a);
+ else
+ shape = fnShape(b);
+ *aa = fnReshape(shape, a);
+ *bb = fnReshape(shape, b);
+ freearray(shape);
+ }else if(a->size == 1 && b->size != 1){
+ Array *shape = fnShape(b);
+ *aa = fnReshape(shape, a);
+ *bb = fnSame(b);
+ freearray(shape);
+ }else if(a->size != 1 && b->size == 1){
+ Array *shape = fnShape(a);
+ *aa = fnSame(a);
+ *bb = fnReshape(shape, b);
+ freearray(shape);
}else
return 0;
return 1;
@@ -148,16 +167,6 @@ commontype(Array *a, Array *b, Array **aa, Array **bb, int forcefloat)
}else if(a->type == AtypeInt && b->type == AtypeFloat){
*aa = inttofloatarray(a);
*bb = fnSame(b);
- }else if(a->type == AtypeArray && b->type != AtypeArray){
- *aa = fnSame(a);
- *bb = allocarray(AtypeArray, 0, 1);
- (*bb)->arraydata[0] = b;
- incref(b);
- }else if(a->type != AtypeArray && b->type == AtypeArray){
- *aa = allocarray(AtypeArray, 0, 1);
- (*aa)->arraydata[0] = a;
- incref(a);
- *bb = fnSame(b);
}else
return 0;
return 1;
diff --git a/functions.c b/functions.c
index 4b93b44..df90cf4 100644
--- a/functions.c
+++ b/functions.c
@@ -892,22 +892,33 @@ fnSelfReference1(Array *right)
#define SCALAR_FUNCTION_2(name, forcefloat, restype, cases) \
Array *name(Array *left, Array *right){\
Array *leftarr, *rightarr;\
- if(!commontype(left, right, &leftarr, &rightarr, forcefloat)) throwerror(nil, EType);\
- if(!scalarextend(leftarr, rightarr, &left, &right)) throwerror(L"scalar extension fail", ERank);\
+ int nested = left->type == AtypeArray || right->type == AtypeArray;\
+ if(nested){\
+ leftarr = fnSame(left);\
+ rightarr = fnSame(right);\
+ }else{\
+ if(!commontype(left, right, &leftarr, &rightarr, forcefloat))\
+ throwerror(nil, EType);\
+ }\
+ if(!scalarextend(leftarr, rightarr, &left, &right)) throwerror(L"Scalar extension fail", ERank);\
Array *res;\
- if(left->type != AtypeArray && restype != left->type)\
- res = duparrayshape(left, restype);\
- else\
- res = duparray(left);\
- for(int i = 0; i < left->size; i++)\
- switch(left->type){\
- default: throwerror(nil, EType); break;\
- case AtypeArray:\
- freearray(res->arraydata[i]);\
- res->arraydata[i] = name(left->arraydata[i], right->arraydata[i]);\
- break;\
- cases\
+ if(nested){\
+ res = duparrayshape(left, AtypeArray);\
+ for(int i = 0; i < left->size; i++){\
+ Array *l = arrayitem(left, i);\
+ Array *r = arrayitem(right, i);\
+ res->arraydata[i] = name(l,r);\
+ freearray(l);\
+ freearray(r);\
}\
+ }else{\
+ res = duparray(left);\
+ for(int i = 0; i < left->size; i++)\
+ switch(left->type){\
+ default: throwerror(nil, EType); break;\
+ cases\
+ }\
+ }\
freearray(leftarr); freearray(rightarr); freearray(left); freearray(right);\
return res;}