summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--apl9.h7
-rw-r--r--array.c11
-rw-r--r--functions.c98
3 files changed, 97 insertions, 19 deletions
diff --git a/apl9.h b/apl9.h
index 94ae192..21c1282 100644
--- a/apl9.h
+++ b/apl9.h
@@ -188,6 +188,7 @@ Array *mkscalarfloat(double);
Array *mkscalarrune(Rune);
Array *mkrunearray(Rune *);
Array *duparray(Array *);
+Array *duparrayshape(Array *, int);
int simplearray(Array *);
int simplescalar(Array *);
Array *extend(Array *, Array *);
@@ -260,6 +261,12 @@ Array *fnMaximum(Array *, Array *);
Array *fnMinimum(Array *, Array *);
Array *fnLeft(Array *, Array *);
Array *fnRight(Array *, Array *);
+Array *fnEqual(Array *, Array *);
+Array *fnNotEqual(Array *, Array *);
+Array *fnLessEqual(Array *, Array *);
+Array *fnLess(Array *, Array *);
+Array *fnGreater(Array *, Array *);
+Array *fnGreaterEqual(Array *, Array *);
Array *fnMatch(Array *, Array *);
Array *fnTake(Array *, Array *);
Array *fnIndex(Array *, Array *);
diff --git a/array.c b/array.c
index bf75a1b..b6bc87b 100644
--- a/array.c
+++ b/array.c
@@ -50,8 +50,7 @@ mkrunearray(Rune *str)
Array *
duparray(Array *a)
{
- Array *b = allocarray(a->type, a->rank, a->size);
- memcpy(b->shape, a->shape, sizeof(int) * a->rank);
+ Array *b = duparrayshape(a, a->type);
memcpy(b->rawdata, a->rawdata, datasizes[a->type]*a->size);
if(b->type == AtypeArray)
for(int i = 0; i < b->size; i++)
@@ -59,6 +58,14 @@ duparray(Array *a)
return b;
}
+Array *
+duparrayshape(Array *a, int type)
+{
+ Array *b = allocarray(type, a->rank, a->size);
+ memcpy(b->shape, a->shape, sizeof(int) * a->rank);
+ return b;
+}
+
int
simplearray(Array *a)
{
diff --git a/functions.c b/functions.c
index 43a2505..07e017b 100644
--- a/functions.c
+++ b/functions.c
@@ -79,12 +79,12 @@ fndyad dyadfunctiondefs[] = {
0, /* ⊤ */
fnLeft, /* ⊣ */
fnRight, /* ⊢ */
- 0, /* = */
- 0, /* ≠ */
- 0, /* ≤ */
- 0, /* < */
- 0, /* > */
- 0, /* ≥ */
+ fnEqual, /* = */
+ fnNotEqual, /* ≠ */
+ fnLessEqual, /* ≤ */
+ fnLess, /* < */
+ fnGreater, /* > */
+ fnGreaterEqual, /* ≥ */
fnMatch, /* ≡ */
0, /* ≢ */
0, /* ∨ */
@@ -526,12 +526,16 @@ fnTranspose(Array *right)
/* Dyadic functions */
/* macro to define dyadic scalar functions */
-#define SCALAR_FUNCTION_2(name, forcefloat, cases) \
+#define SCALAR_FUNCTION_2(name, forcefloat, restype, cases) \
Array *name(Array *left, Array *right){\
Array *leftarr, *rightarr;\
if(!commontype(left, right, &leftarr, &rightarr, forcefloat)) throwerror(nil, EType);\
if(!scalarextend(leftarr, rightarr, &left, &right)) throwerror(nil, ERank);\
- Array *res = duparray(left);\
+ Array *res;\
+ if(left->type != AtypeArray && restype != left->type)\
+ res = duparrayshape(left, restype);\
+ else\
+ res = duparray(left);\
for(int i = 0; i < left->size; i++)\
switch(res->type){\
default: throwerror(nil, EType); break;\
@@ -544,7 +548,7 @@ Array *name(Array *left, Array *right){\
freearray(leftarr); freearray(rightarr); freearray(left); freearray(right);\
return res;}
-SCALAR_FUNCTION_2(fnPlus, 0,
+SCALAR_FUNCTION_2(fnPlus, 0, left->type,
case AtypeFloat:
res->floatdata[i] += right->floatdata[i];
break;
@@ -553,7 +557,7 @@ SCALAR_FUNCTION_2(fnPlus, 0,
break;
)
-SCALAR_FUNCTION_2(fnMinus, 0,
+SCALAR_FUNCTION_2(fnMinus, 0, left->type,
case AtypeFloat:
res->floatdata[i] -= right->floatdata[i];
break;
@@ -562,7 +566,7 @@ SCALAR_FUNCTION_2(fnMinus, 0,
break;
)
-SCALAR_FUNCTION_2(fnTimes, 0,
+SCALAR_FUNCTION_2(fnTimes, 0, left->type,
case AtypeFloat:
res->floatdata[i] *= right->floatdata[i];
break;
@@ -571,13 +575,13 @@ SCALAR_FUNCTION_2(fnTimes, 0,
break;
)
-SCALAR_FUNCTION_2(fnDivide, 1,
+SCALAR_FUNCTION_2(fnDivide, 1, left->type,
case AtypeFloat:
res->floatdata[i] /= right->floatdata[i];
break;
)
-SCALAR_FUNCTION_2(fnPower, 0,
+SCALAR_FUNCTION_2(fnPower, 0, left->type,
case AtypeFloat:
res->floatdata[i] = pow(res->floatdata[i], right->floatdata[i]);
break;
@@ -586,13 +590,13 @@ SCALAR_FUNCTION_2(fnPower, 0,
break;
)
-SCALAR_FUNCTION_2(fnLogarithm, 1,
+SCALAR_FUNCTION_2(fnLogarithm, 1, left->type,
case AtypeFloat:
res->floatdata[i] = log(right->floatdata[i])/log(res->floatdata[i]);
break;
)
-SCALAR_FUNCTION_2(fnResidue, 1,
+SCALAR_FUNCTION_2(fnResidue, 1, left->type,
case AtypeFloat:
if(res->floatdata[i] == 0)
res->floatdata[i] = right->floatdata[i];
@@ -600,7 +604,7 @@ SCALAR_FUNCTION_2(fnResidue, 1,
res->floatdata[i] = right->floatdata[i] - res->floatdata[i] * floor(right->floatdata[i]/res->floatdata[i]);
)
-SCALAR_FUNCTION_2(fnMaximum, 0,
+SCALAR_FUNCTION_2(fnMaximum, 0, left->type,
case AtypeFloat:
if(res->floatdata[i] < right->floatdata[i])
res->floatdata[i] = right->floatdata[i];
@@ -609,7 +613,7 @@ SCALAR_FUNCTION_2(fnMaximum, 0,
res->intdata[i] = right->intdata[i];
)
-SCALAR_FUNCTION_2(fnMinimum, 0,
+SCALAR_FUNCTION_2(fnMinimum, 0, left->type,
case AtypeFloat:
if(res->floatdata[i] > right->floatdata[i])
res->floatdata[i] = right->floatdata[i];
@@ -634,6 +638,66 @@ fnRight(Array *left, Array *right)
return right;
}
+SCALAR_FUNCTION_2(fnEqual, 0, AtypeInt,
+ case AtypeFloat:
+ res->intdata[i] = left->floatdata[i] == right->floatdata[i];
+ break;
+ case AtypeInt:
+ res->intdata[i] = left->intdata[i] == right->intdata[i];
+ break;
+ case AtypeRune:
+ res->intdata[i] = left->runedata[i] == right->runedata[i];
+ break;
+)
+
+SCALAR_FUNCTION_2(fnNotEqual, 0, AtypeInt,
+ case AtypeFloat:
+ res->intdata[i] = left->floatdata[i] != right->floatdata[i];
+ break;
+ case AtypeInt:
+ res->intdata[i] = left->intdata[i] != right->intdata[i];
+ break;
+ case AtypeRune:
+ res->intdata[i] = left->runedata[i] != right->runedata[i];
+ break;
+)
+
+SCALAR_FUNCTION_2(fnLessEqual, 0, AtypeInt,
+ case AtypeFloat:
+ res->intdata[i] = left->floatdata[i] <= right->floatdata[i];
+ break;
+ case AtypeInt:
+ res->intdata[i] = left->intdata[i] <= right->intdata[i];
+ break;
+)
+
+SCALAR_FUNCTION_2(fnLess, 0, AtypeInt,
+ case AtypeFloat:
+ res->intdata[i] = left->floatdata[i] < right->floatdata[i];
+ break;
+ case AtypeInt:
+ res->intdata[i] = left->intdata[i] < right->intdata[i];
+ break;
+)
+
+SCALAR_FUNCTION_2(fnGreater, 0, AtypeInt,
+ case AtypeFloat:
+ res->intdata[i] = left->floatdata[i] > right->floatdata[i];
+ break;
+ case AtypeInt:
+ res->intdata[i] = left->intdata[i] > right->intdata[i];
+ break;
+)
+
+SCALAR_FUNCTION_2(fnGreaterEqual, 0, AtypeInt,
+ case AtypeFloat:
+ res->intdata[i] = left->floatdata[i] >= right->floatdata[i];
+ break;
+ case AtypeInt:
+ res->intdata[i] = left->intdata[i] >= right->intdata[i];
+ break;
+)
+
Array *
fnMatch(Array *left, Array *right)
{