“计算图”(computational graph)是现代深度学习系统的基础执行引擎,提供了一种表示任意数学表达式的方法,例如用有向无环图表示的神经网络。 图中的节点表示基本操作或输入变量,边表示节点之间的中间值的依赖性。 例如,下图就是一个函数 ( 的计算图。
现在给定一个计算图,请你根据所有输入变量计算函数值及其偏导数(即梯度)。 例如,给定输入,,上述计算图获得函数值 (;并且根据微分链式法则,上图得到的梯度 ∇。
知道你已经把微积分忘了,所以这里只要求你处理几个简单的算子:加法、减法、乘法、指数(ex,即编程语言中的 exp(x) 函数)、对数(ln,即编程语言中的 log(x) 函数)和正弦函数(sin,即编程语言中的 sin(x) 函数)。
友情提醒:
常数的导数是 0;x 的导数是 1;ex 的导数还是 ex;ln 的导数是 1;sin 的导数是 cos。 回顾一下什么是偏导数:在数学中,一个多变量的函数的偏导数,就是它关于其中一个变量的导数而保持其他变量恒定。在上面的例子中,当我们对 x1 求偏导数 / 时,就将 x2 当成常数,所以得到 ln 的导数是 1,x1x2 的导数是 x2,sin 的导数是 0。 回顾一下链式法则:复合函数的导数是构成复合这有限个函数在相应点的导数的乘积,即若有 (,(,则 /。例如对 sin 求导,就得到 cos。 如果你注意观察,可以发现在计算图中,计算函数值是一个从左向右进行的计算,而计算偏导数则正好相反。
输入格式: 输入在第一行给出正整数 N(≤),为计算图中的顶点数。
以下 N 行,第 i 行给出第 i 个顶点的信息,其中 ,。第一个值是顶点的类型编号,分别为:
0 代表输入变量 1 代表加法,对应 x1 x2 2 代表减法,对应 x1−x2 3 代表乘法,对应 x1×x2 4 代表指数,对应 ex 5 代表对数,对应 ln 6 代表正弦函数,对应 sin 对于输入变量,后面会跟它的双精度浮点数值;对于单目算子,后面会跟它对应的单个变量的顶点编号(编号从 0 开始);对于双目算子,后面会跟它对应两个变量的顶点编号。
题目保证只有一个输出顶点(即没有出边的顶点,例如上图最右边的 -),且计算过程不会超过双精度浮点数的计算精度范围。
输出格式: 首先在第一行输出给定计算图的函数值。在第二行顺序输出函数对于每个变量的偏导数的值,其间以一个空格分隔,行首尾不得有多余空格。偏导数的输出顺序与输入变量的出现顺序相同。输出小数点后 3 位。
输入样例:
代码语言:javascript复制7
0 2.0
0 5.0
5 0
3 0 1
6 1
1 2 3
2 5 4
输出样例:
代码语言:javascript复制11.652
5.500 1.716
题解 将每个节点的输入节点编号存入到节点结构体中,然后先正向bfs求每个节点的输出值,然后再方向求每个节点的导数,注意求导数的时候每个节点存储的导数是其输出变量的导数,求解的时候应该按照不同路径的导数相加,同一路径上的导数相乘。 无论是正向还是方向均应按照拓扑序求解
代码语言:javascript复制#include<bits/stdc .h>
#include<cmath>
#define x first
#define y second
#define send string::npos
#define lowbit(x) (x&(-x))
#define left(x) x<<1
#define right(x) x<<1|1
using namespace std;
typedef long long ll;
typedef pair<int,int> PII;
typedef struct Node * pnode;
const int N = 1e6 10;
const int M = 3 * N;
const int INF = 0x3f3f3f3f;
const ll LINF = 0x3f3f3f3f3f3f3f3f;
const int Mod = 1e9;
int out[N],in[N];
struct Node{
double v,f; //v代表此节点输出值,f代表输出值导数
int la,lb;
int op;
}node[N];
int head[N],cnt;
int q[N],tt = 0,hh = 0;
vector<int>s; //起始节点
struct Edge{
int v,next;
}edge[2 * M];
void add(int u,int v){
edge[cnt].v = v;
edge[cnt].next = head[u];
head[u] = cnt ;
}
double op123(int t,double a,double b){
if(t == 1)return a b;
if(t == 2)return a - b;
if(t == 3)return a * b;
}
double op456(int t,double a){
if(t == 4)return exp(a);
if(t == 5)return log(a);
if(t == 6)return sin(a);
}
void bfs(){
for(int i = 0;i < s.size();i )q[tt ] = s[i];
while(hh < tt){
int t = q[hh ];
// cout<<t<<endl;
double a = node[node[t].la].v,b = node[node[t].lb].v;
int type = node[t].op;
if(type == 1 || type == 2 || type == 3)node[t].v = op123(type,a,b);
else if(type != 0)node[t].v = op456(type,a);
// cout<<t<<" "<<type<<" "<<node[t].v<<endl;
for(int i = head[t];~i;i = edge[i].next){
int v = edge[i].v;
in[v] --;
if(in[v] == 0) //如果此处不按拓扑序,则会产生大量的重复节点
q[tt ] = v;
}
}
}
void top(int root){
hh = tt = 0;
q[tt ] = root;
while(hh < tt){
int t = q[hh ];
// cout<<t<<endl;
if(node[t].op == 1 || node[t].op == 2 || node[t].op == 3){
if(node[t].op == 1){
node[node[t].la].f = (1 * node[t].f);
node[node[t].lb].f = (1 * node[t].f);
}else if(node[t].op == 2){
node[node[t].la].f = (1 * node[t].f);
node[node[t].lb].f = (-1 * node[t].f);
}else{
node[node[t].la].f = (node[node[t].lb].v * node[t].f);
node[node[t].lb].f = (node[node[t].la].v * node[t].f);
}
out[node[t].la] --,out[node[t].lb] --;
if(out[node[t].la] == 0)q[tt ] = node[t].la;
if(out[node[t].lb] == 0)q[tt ] = node[t].lb;
}
else if(node[t].op != 0){
if(node[t].op == 4)node[node[t].la].f = (node[t].v * node[t].f);
else if(node[t].op == 5)node[node[t].la].f = (node[t].f / node[node[t].la].v) ;
else node[node[t].la].f = (cos(node[node[t].la].v) * node[t].f);
out[node[t].la] --;
if(out[node[t].la] == 0)q[tt ] = node[t].la;
}
}
}
int main(){
memset(head,-1,sizeof head);
int n,t,a,b;
double x;
cin>>n;
for(int i = 0;i < n;i ){
cin>>t;
if(t == 0){
cin>>x;
s.push_back(i);
node[i].v = x;
}
else if(t == 1 || t == 2 || t == 3){
cin>>a>>b;
add(a,i);
add(b,i);
out[a] ,out[b] ;
in[i] ,in[i] ;
node[i].la = a,node[i].lb = b;
}else if(t == 4 || t == 5 || t == 6){
cin>>a;
add(a,i);
out[a] ;
in[i] ;
node[i].la = a;
}
node[i].op = t;
}
int root = -1;
for(int i = 0;i < n;i )
if(!out[i])
root = i;
// cout<<"root:"<<root<<endl;
bfs();
node[root].f = 1; //最后一个节点的输出值导数应该为1
top(root);
printf("%.3fn%.3f",node[root].v,node[s[0]].f);
for(int i = 1;i < s.size();i )printf(" %.3f",node[s[i]].f);
return 0;
}
发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/168907.html原文链接:https://javaforall.cn