summaryrefslogtreecommitdiff
path: root/array.c
diff options
context:
space:
mode:
Diffstat (limited to 'array.c')
-rw-r--r--array.c29
1 files changed, 19 insertions, 10 deletions
diff --git a/array.c b/array.c
index e71d938..c29c411 100644
--- a/array.c
+++ b/array.c
@@ -114,6 +114,25 @@ scalarextend(Array *a, Array *b, Array **aa, Array **bb)
}
*aa = fnSame(a);
*bb = fnSame(b);
+ }else if(a->size == 1 && b->size == 1){
+ Array *shape;
+ if(a->rank > b->rank)
+ shape = fnShape(a);
+ else
+ shape = fnShape(b);
+ *aa = fnReshape(shape, a);
+ *bb = fnReshape(shape, b);
+ freearray(shape);
+ }else if(a->size == 1 && b->size != 1){
+ Array *shape = fnShape(b);
+ *aa = fnReshape(shape, a);
+ *bb = fnSame(b);
+ freearray(shape);
+ }else if(a->size != 1 && b->size == 1){
+ Array *shape = fnShape(a);
+ *aa = fnSame(a);
+ *bb = fnReshape(shape, b);
+ freearray(shape);
}else
return 0;
return 1;
@@ -148,16 +167,6 @@ commontype(Array *a, Array *b, Array **aa, Array **bb, int forcefloat)
}else if(a->type == AtypeInt && b->type == AtypeFloat){
*aa = inttofloatarray(a);
*bb = fnSame(b);
- }else if(a->type == AtypeArray && b->type != AtypeArray){
- *aa = fnSame(a);
- *bb = allocarray(AtypeArray, 0, 1);
- (*bb)->arraydata[0] = b;
- incref(b);
- }else if(a->type != AtypeArray && b->type == AtypeArray){
- *aa = allocarray(AtypeArray, 0, 1);
- (*aa)->arraydata[0] = a;
- incref(a);
- *bb = fnSame(b);
}else
return 0;
return 1;